mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Fixed big issue with bucketing dataloader and added random cripping to a point of interest
This commit is contained in:
@@ -12,7 +12,7 @@ import torch
|
||||
import torch.backends.cuda
|
||||
|
||||
from toolkit.basic import value_map
|
||||
from toolkit.data_loader import get_dataloader_from_datasets
|
||||
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
|
||||
from toolkit.embedding import Embedding
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
@@ -931,16 +931,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
||||
batch = next(dataloader_iterator_reg)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator_reg = iter(dataloader_reg)
|
||||
trigger_dataloader_setup_epoch(dataloader_reg)
|
||||
batch = next(dataloader_iterator_reg)
|
||||
self.progress_bar.unpause()
|
||||
is_reg_step = True
|
||||
elif dataloader is not None:
|
||||
try:
|
||||
batch = next(dataloader_iterator)
|
||||
except StopIteration:
|
||||
# hit the end of an epoch, reset
|
||||
self.progress_bar.pause()
|
||||
dataloader_iterator = iter(dataloader)
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
batch = next(dataloader_iterator)
|
||||
self.progress_bar.unpause()
|
||||
else:
|
||||
batch = None
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import sys
|
||||
import os
|
||||
@@ -16,12 +17,14 @@ sys.path.append(SD_SCRIPTS_ROOT)
|
||||
|
||||
from library.model_util import load_vae
|
||||
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
|
||||
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
|
||||
trigger_dataloader_setup_epoch
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('dataset_folder', type=str, default='input')
|
||||
parser.add_argument('--epochs', type=int, default=1)
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
@@ -34,38 +37,44 @@ batch_size = 4
|
||||
dataset_config = DatasetConfig(
|
||||
dataset_path=dataset_folder,
|
||||
resolution=resolution,
|
||||
caption_ext='txt',
|
||||
caption_ext='json',
|
||||
default_caption='default',
|
||||
buckets=True,
|
||||
bucket_tolerance=bucket_tolerance,
|
||||
augments=['ColorJitter', 'RandomEqualize'],
|
||||
augments=['ColorJitter'],
|
||||
poi='person'
|
||||
|
||||
)
|
||||
|
||||
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
|
||||
|
||||
|
||||
# run through an epoch ang check sizes
|
||||
for batch in dataloader:
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
dataloader_iterator = iter(dataloader)
|
||||
for epoch in range(args.epochs):
|
||||
for batch in dataloader:
|
||||
batch: 'DataLoaderBatchDTO'
|
||||
img_batch = batch.tensor
|
||||
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are size by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
chunks = torch.chunk(img_batch, batch_size, dim=0)
|
||||
# put them so they are size by side
|
||||
big_img = torch.cat(chunks, dim=3)
|
||||
big_img = big_img.squeeze(0)
|
||||
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
min_val = big_img.min()
|
||||
max_val = big_img.max()
|
||||
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
big_img = (big_img / 2 + 0.5).clamp(0, 1)
|
||||
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
# convert to image
|
||||
img = transforms.ToPILImage()(big_img)
|
||||
|
||||
show_img(img)
|
||||
show_img(img)
|
||||
|
||||
time.sleep(1.0)
|
||||
time.sleep(1.0)
|
||||
# if not last epoch
|
||||
if epoch < args.epochs - 1:
|
||||
trigger_dataloader_setup_epoch(dataloader)
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
@@ -234,7 +234,8 @@ class DatasetConfig:
|
||||
self.flip_x: bool = kwargs.get('flip_x', False)
|
||||
self.flip_y: bool = kwargs.get('flip_y', False)
|
||||
self.augments: List[str] = kwargs.get('augments', [])
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
|
||||
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
|
||||
|
||||
# cache latents will store them in memory
|
||||
self.cache_latents: bool = kwargs.get('cache_latents', False)
|
||||
|
||||
@@ -346,6 +346,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
|
||||
self.is_caching_latents_to_memory = dataset_config.cache_latents
|
||||
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
|
||||
self.epoch_num = 0
|
||||
|
||||
self.sd = sd
|
||||
|
||||
@@ -426,13 +427,20 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
self.setup_epoch()
|
||||
|
||||
def setup_epoch(self):
|
||||
# TODO: set this up to redo cropping and everything else
|
||||
# do not call for now
|
||||
if self.dataset_config.buckets:
|
||||
# setup buckets
|
||||
self.setup_buckets()
|
||||
if self.is_caching_latents:
|
||||
self.cache_latents_all_latents()
|
||||
if self.epoch_num == 0:
|
||||
# initial setup
|
||||
# do not call for now
|
||||
if self.dataset_config.buckets:
|
||||
# setup buckets
|
||||
self.setup_buckets()
|
||||
if self.is_caching_latents:
|
||||
self.cache_latents_all_latents()
|
||||
else:
|
||||
if self.dataset_config.poi is not None:
|
||||
# handle cropping to a specific point of interest
|
||||
# setup buckets every epoch
|
||||
self.setup_buckets(quiet=True)
|
||||
self.epoch_num += 1
|
||||
|
||||
def __len__(self):
|
||||
if self.dataset_config.buckets:
|
||||
@@ -450,6 +458,9 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
|
||||
# for buckets we collate ourselves for now
|
||||
# todo allow a scheduler to dynamically make buckets
|
||||
# we collate ourselves
|
||||
if len(self.batch_indices) - 1 < item:
|
||||
# tried everything to solve this. No way to reset length when redoing things. Pick another index
|
||||
item = random.randint(0, len(self.batch_indices) - 1)
|
||||
idx_list = self.batch_indices[item]
|
||||
return [self._get_single_item(idx) for idx in idx_list]
|
||||
else:
|
||||
@@ -523,3 +534,27 @@ def get_dataloader_from_datasets(
|
||||
collate_fn=dto_collation
|
||||
)
|
||||
return data_loader
|
||||
|
||||
|
||||
def trigger_dataloader_setup_epoch(dataloader: DataLoader):
|
||||
# hacky but needed because of different types of datasets and dataloaders
|
||||
dataloader.len = None
|
||||
if isinstance(dataloader.dataset, list):
|
||||
for dataset in dataloader.dataset:
|
||||
if hasattr(dataset, 'datasets'):
|
||||
for sub_dataset in dataset.datasets:
|
||||
if hasattr(sub_dataset, 'setup_epoch'):
|
||||
sub_dataset.setup_epoch()
|
||||
sub_dataset.len = None
|
||||
elif hasattr(dataset, 'setup_epoch'):
|
||||
dataset.setup_epoch()
|
||||
dataset.len = None
|
||||
elif hasattr(dataloader.dataset, 'setup_epoch'):
|
||||
dataloader.dataset.setup_epoch()
|
||||
dataloader.dataset.len = None
|
||||
elif hasattr(dataloader.dataset, 'datasets'):
|
||||
dataloader.dataset.len = None
|
||||
for sub_dataset in dataloader.dataset.datasets:
|
||||
if hasattr(sub_dataset, 'setup_epoch'):
|
||||
sub_dataset.setup_epoch()
|
||||
sub_dataset.len = None
|
||||
|
||||
@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
|
||||
|
||||
from toolkit import image_utils
|
||||
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
|
||||
ControlFileItemDTOMixin, ArgBreakMixin
|
||||
ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
@@ -27,6 +27,7 @@ class FileItemDTO(
|
||||
CaptionProcessingDTOMixin,
|
||||
ImageProcessingDTOMixin,
|
||||
ControlFileItemDTOMixin,
|
||||
PoiFileItemDTOMixin,
|
||||
ArgBreakMixin,
|
||||
):
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -70,20 +71,25 @@ class FileItemDTO(
|
||||
|
||||
class DataLoaderBatchDTO:
|
||||
def __init__(self, **kwargs):
|
||||
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
|
||||
is_latents_cached = self.file_items[0].is_latent_cached
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
# if we have encoded latents, we concatenate them
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
if self.file_items[0].control_tensor is not None:
|
||||
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
|
||||
try:
|
||||
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
|
||||
is_latents_cached = self.file_items[0].is_latent_cached
|
||||
self.tensor: Union[torch.Tensor, None] = None
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if not is_latents_cached:
|
||||
# only return a tensor if latents are not cached
|
||||
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
|
||||
# if we have encoded latents, we concatenate them
|
||||
self.latents: Union[torch.Tensor, None] = None
|
||||
if is_latents_cached:
|
||||
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
|
||||
self.control_tensor: Union[torch.Tensor, None] = None
|
||||
if self.file_items[0].control_tensor is not None:
|
||||
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
|
||||
def get_is_reg_list(self):
|
||||
return [x.is_reg for x in self.file_items]
|
||||
|
||||
@@ -29,11 +29,13 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
transforms_dict = {
|
||||
'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01),
|
||||
'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
|
||||
'RandomEqualize': transforms.RandomEqualize(p=0.2),
|
||||
}
|
||||
|
||||
caption_ext_list = ['txt', 'json', 'caption']
|
||||
|
||||
|
||||
class CaptionMixin:
|
||||
def get_caption_item(self: 'AiToolkitDataset', index):
|
||||
if not hasattr(self, 'caption_type'):
|
||||
@@ -106,66 +108,96 @@ class BucketsMixin:
|
||||
self.batch_indices: List[List[int]] = []
|
||||
|
||||
def build_batch_indices(self: 'AiToolkitDataset'):
|
||||
self.batch_indices = []
|
||||
for key, bucket in self.buckets.items():
|
||||
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
|
||||
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
|
||||
batch = bucket.file_list_idx[start_idx:end_idx]
|
||||
self.batch_indices.append(batch)
|
||||
|
||||
def setup_buckets(self: 'AiToolkitDataset'):
|
||||
def setup_buckets(self: 'AiToolkitDataset', quiet=False):
|
||||
if not hasattr(self, 'file_list'):
|
||||
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
|
||||
if not hasattr(self, 'dataset_config'):
|
||||
raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
|
||||
|
||||
if self.epoch_num > 0 and self.dataset_config.poi is None:
|
||||
# no need to rebuild buckets for now
|
||||
# todo handle random cropping for buckets
|
||||
return
|
||||
self.buckets = {} # clear it
|
||||
|
||||
config: 'DatasetConfig' = self.dataset_config
|
||||
resolution = config.resolution
|
||||
bucket_tolerance = config.bucket_tolerance
|
||||
file_list: List['FileItemDTO'] = self.file_list
|
||||
|
||||
total_pixels = resolution * resolution
|
||||
|
||||
# for file_item in enumerate(file_list):
|
||||
for idx, file_item in enumerate(file_list):
|
||||
file_item: 'FileItemDTO' = file_item
|
||||
width = file_item.crop_width
|
||||
height = file_item.crop_height
|
||||
width = int(file_item.width * file_item.dataset_config.scale)
|
||||
height = int(file_item.height * file_item.dataset_config.scale)
|
||||
|
||||
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution,
|
||||
divisibility=bucket_tolerance)
|
||||
|
||||
# set the scaling height and with to match smallest size, and keep aspect ratio
|
||||
if width > height:
|
||||
file_item.scale_to_height = bucket_resolution["height"]
|
||||
file_item.scale_to_width = int(width * (bucket_resolution["height"] / height))
|
||||
if file_item.has_point_of_interest:
|
||||
# let the poi module handle the bucketing
|
||||
file_item.setup_poi_bucket()
|
||||
else:
|
||||
file_item.scale_to_width = bucket_resolution["width"]
|
||||
file_item.scale_to_height = int(height * (bucket_resolution["width"] / width))
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
width, height,
|
||||
resolution=resolution,
|
||||
divisibility=bucket_tolerance
|
||||
)
|
||||
|
||||
file_item.crop_height = bucket_resolution["height"]
|
||||
file_item.crop_width = bucket_resolution["width"]
|
||||
# Calculate scale factors for width and height
|
||||
width_scale_factor = bucket_resolution["width"] / width
|
||||
height_scale_factor = bucket_resolution["height"] / height
|
||||
|
||||
new_width = bucket_resolution["width"]
|
||||
new_height = bucket_resolution["height"]
|
||||
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
|
||||
max_scale_factor = max(width_scale_factor, height_scale_factor)
|
||||
|
||||
file_item.scale_to_width = int(width * max_scale_factor)
|
||||
file_item.scale_to_height = int(height * max_scale_factor)
|
||||
|
||||
file_item.crop_height = bucket_resolution["height"]
|
||||
file_item.crop_width = bucket_resolution["width"]
|
||||
|
||||
new_width = bucket_resolution["width"]
|
||||
new_height = bucket_resolution["height"]
|
||||
|
||||
if self.dataset_config.random_crop:
|
||||
# random crop
|
||||
crop_x = random.randint(0, file_item.scale_to_width - new_width)
|
||||
crop_y = random.randint(0, file_item.scale_to_height - new_height)
|
||||
file_item.crop_x = crop_x
|
||||
file_item.crop_y = crop_y
|
||||
else:
|
||||
# do central crop
|
||||
file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
|
||||
file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
|
||||
|
||||
if file_item.crop_y < 0 or file_item.crop_x < 0:
|
||||
print('debug')
|
||||
|
||||
# check if bucket exists, if not, create it
|
||||
bucket_key = f'{new_width}x{new_height}'
|
||||
bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
|
||||
if bucket_key not in self.buckets:
|
||||
self.buckets[bucket_key] = Bucket(new_width, new_height)
|
||||
self.buckets[bucket_key] = Bucket(file_item.crop_width, file_item.crop_height)
|
||||
self.buckets[bucket_key].file_list_idx.append(idx)
|
||||
|
||||
# print the buckets
|
||||
self.build_batch_indices()
|
||||
name = f"{os.path.basename(self.dataset_path)} ({self.resolution})"
|
||||
print(f'Bucket sizes for {self.dataset_path}:')
|
||||
for key, bucket in self.buckets.items():
|
||||
print(f'{key}: {len(bucket.file_list_idx)} files')
|
||||
print(f'{len(self.buckets)} buckets made')
|
||||
if not quiet:
|
||||
print(f'Bucket sizes for {self.dataset_path}:')
|
||||
for key, bucket in self.buckets.items():
|
||||
print(f'{key}: {len(bucket.file_list_idx)} files')
|
||||
print(f'{len(self.buckets)} buckets made')
|
||||
|
||||
# file buckets made
|
||||
|
||||
|
||||
class CaptionProcessingDTOMixin:
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# todo allow for loading from sd-scripts style dict
|
||||
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
|
||||
@@ -281,10 +313,19 @@ class ImageProcessingDTOMixin:
|
||||
img.transpose(Image.FLIP_TOP_BOTTOM)
|
||||
|
||||
if self.dataset_config.buckets:
|
||||
# todo allow scaling and cropping, will be hard to add
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
|
||||
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
|
||||
print('size mismatch')
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
else:
|
||||
# Downscale the source image first
|
||||
# TODO this is nto right
|
||||
@@ -371,7 +412,14 @@ class ControlFileItemDTOMixin:
|
||||
if self.dataset_config.buckets:
|
||||
# scale and crop based on file item
|
||||
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
|
||||
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
|
||||
# crop
|
||||
img = img.crop((
|
||||
self.crop_x,
|
||||
self.crop_y,
|
||||
self.crop_x + self.crop_width,
|
||||
self.crop_y + self.crop_height
|
||||
))
|
||||
else:
|
||||
raise Exception("Control images not supported for non-bucket datasets")
|
||||
|
||||
@@ -381,6 +429,93 @@ class ControlFileItemDTOMixin:
|
||||
self.control_tensor = None
|
||||
|
||||
|
||||
class PoiFileItemDTOMixin:
|
||||
# Point of interest bounding box. Allows for dynamic cropping without cropping out the main subject
|
||||
# items in the poi will always be inside the image when random cropping
|
||||
def __init__(self: 'FileItemDTO', *args, **kwargs):
|
||||
if hasattr(super(), '__init__'):
|
||||
super().__init__(*args, **kwargs)
|
||||
# poi is a name of the box point of interest in the caption json file
|
||||
dataset_config = kwargs.get('dataset_config', None)
|
||||
path = kwargs.get('path', None)
|
||||
self.poi: Union[str, None] = dataset_config.poi
|
||||
self.has_point_of_interest = self.poi is not None
|
||||
self.poi_x: Union[int, None] = None
|
||||
self.poi_y: Union[int, None] = None
|
||||
self.poi_width: Union[int, None] = None
|
||||
self.poi_height: Union[int, None] = None
|
||||
|
||||
if self.poi is not None:
|
||||
# make sure latent caching is off
|
||||
if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
|
||||
raise Exception(
|
||||
f"Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
|
||||
)
|
||||
# make sure we are loading through json
|
||||
if dataset_config.caption_ext != 'json':
|
||||
raise Exception(
|
||||
f"Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
|
||||
)
|
||||
self.poi = self.poi.strip()
|
||||
# get the caption path
|
||||
file_path_no_ext = os.path.splitext(path)[0]
|
||||
caption_path = file_path_no_ext + '.json'
|
||||
if not os.path.exists(caption_path):
|
||||
raise Exception(f"Error: caption file not found for poi: {caption_path}")
|
||||
with open(caption_path, 'r', encoding='utf-8') as f:
|
||||
json_data = json.load(f)
|
||||
if 'poi' not in json_data:
|
||||
raise Exception(f"Error: poi not found in caption file: {caption_path}")
|
||||
if self.poi not in json_data['poi']:
|
||||
raise Exception(f"Error: poi not found in caption file: {caption_path}")
|
||||
# poi has, x, y, width, height
|
||||
poi = json_data['poi'][self.poi]
|
||||
self.poi_x = int(poi['x'])
|
||||
self.poi_y = int(poi['y'])
|
||||
self.poi_width = int(poi['width'])
|
||||
self.poi_height = int(poi['height'])
|
||||
|
||||
def setup_poi_bucket(self: 'FileItemDTO'):
|
||||
# we are using poi, so we need to calculate the bucket based on the poi
|
||||
|
||||
resolution = self.dataset_config.resolution
|
||||
bucket_tolerance = self.dataset_config.bucket_tolerance
|
||||
initial_width = int(self.width * self.dataset_config.scale)
|
||||
initial_height = int(self.height * self.dataset_config.scale)
|
||||
poi_x = int(self.poi_x * self.dataset_config.scale)
|
||||
poi_y = int(self.poi_y * self.dataset_config.scale)
|
||||
poi_width = int(self.poi_width * self.dataset_config.scale)
|
||||
poi_height = int(self.poi_height * self.dataset_config.scale)
|
||||
|
||||
# todo handle a poi that is smaller than resolution
|
||||
# determine new cropping
|
||||
crop_left = random.randint(0, poi_x)
|
||||
crop_right = random.randint(poi_x + poi_width, initial_width)
|
||||
crop_top = random.randint(0, poi_y)
|
||||
crop_bottom = random.randint(poi_y + poi_height, initial_height)
|
||||
|
||||
new_width = crop_right - crop_left
|
||||
new_height = crop_bottom - crop_top
|
||||
|
||||
bucket_resolution = get_bucket_for_image_size(
|
||||
new_width, new_height,
|
||||
resolution=resolution,
|
||||
divisibility=bucket_tolerance
|
||||
)
|
||||
|
||||
width_scale_factor = bucket_resolution["width"] / new_width
|
||||
height_scale_factor = bucket_resolution["height"] / new_height
|
||||
# Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
|
||||
max_scale_factor = max(width_scale_factor, height_scale_factor)
|
||||
|
||||
self.scale_to_width = int(initial_width * max_scale_factor)
|
||||
self.scale_to_height = int(initial_height * max_scale_factor)
|
||||
self.crop_width = bucket_resolution['width']
|
||||
self.crop_height = bucket_resolution['height']
|
||||
self.crop_x = int(crop_left * max_scale_factor)
|
||||
self.crop_y = int(crop_top * max_scale_factor)
|
||||
|
||||
|
||||
class ArgBreakMixin:
|
||||
# just stops super calls form hitting object
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user