Fixed big issue with bucketing dataloader and added random cropping to a point of interest

This commit is contained in:
Jaret Burkett
2023-10-02 18:31:08 -06:00
parent 320e109c5f
commit 579650eaf8
6 changed files with 264 additions and 72 deletions

View File

@@ -12,7 +12,7 @@ import torch
import torch.backends.cuda import torch.backends.cuda
from toolkit.basic import value_map from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter from toolkit.ip_adapter import IPAdapter
@@ -931,16 +931,22 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch = next(dataloader_iterator_reg) batch = next(dataloader_iterator_reg)
except StopIteration: except StopIteration:
# hit the end of an epoch, reset # hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator_reg = iter(dataloader_reg) dataloader_iterator_reg = iter(dataloader_reg)
trigger_dataloader_setup_epoch(dataloader_reg)
batch = next(dataloader_iterator_reg) batch = next(dataloader_iterator_reg)
self.progress_bar.unpause()
is_reg_step = True is_reg_step = True
elif dataloader is not None: elif dataloader is not None:
try: try:
batch = next(dataloader_iterator) batch = next(dataloader_iterator)
except StopIteration: except StopIteration:
# hit the end of an epoch, reset # hit the end of an epoch, reset
self.progress_bar.pause()
dataloader_iterator = iter(dataloader) dataloader_iterator = iter(dataloader)
trigger_dataloader_setup_epoch(dataloader)
batch = next(dataloader_iterator) batch = next(dataloader_iterator)
self.progress_bar.unpause()
else: else:
batch = None batch = None

View File

@@ -2,6 +2,7 @@ import time
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader
from torchvision import transforms from torchvision import transforms
import sys import sys
import os import os
@@ -16,12 +17,14 @@ sys.path.append(SD_SCRIPTS_ROOT)
from library.model_util import load_vae from library.model_util import load_vae
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, \
trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig from toolkit.config_modules import DatasetConfig
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input') parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
@@ -34,38 +37,44 @@ batch_size = 4
dataset_config = DatasetConfig( dataset_config = DatasetConfig(
dataset_path=dataset_folder, dataset_path=dataset_folder,
resolution=resolution, resolution=resolution,
caption_ext='txt', caption_ext='json',
default_caption='default', default_caption='default',
buckets=True, buckets=True,
bucket_tolerance=bucket_tolerance, bucket_tolerance=bucket_tolerance,
augments=['ColorJitter', 'RandomEqualize'], augments=['ColorJitter'],
poi='person'
) )
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size) dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
# run through an epoch ang check sizes # run through an epoch ang check sizes
for batch in dataloader: dataloader_iterator = iter(dataloader)
batch: 'DataLoaderBatchDTO' for epoch in range(args.epochs):
img_batch = batch.tensor for batch in dataloader:
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
chunks = torch.chunk(img_batch, batch_size, dim=0) chunks = torch.chunk(img_batch, batch_size, dim=0)
# put them so they are size by side # put them so they are size by side
big_img = torch.cat(chunks, dim=3) big_img = torch.cat(chunks, dim=3)
big_img = big_img.squeeze(0) big_img = big_img.squeeze(0)
min_val = big_img.min() min_val = big_img.min()
max_val = big_img.max() max_val = big_img.max()
big_img = (big_img / 2 + 0.5).clamp(0, 1) big_img = (big_img / 2 + 0.5).clamp(0, 1)
# convert to image # convert to image
img = transforms.ToPILImage()(big_img) img = transforms.ToPILImage()(big_img)
show_img(img) show_img(img)
time.sleep(1.0) time.sleep(1.0)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)
cv2.destroyAllWindows() cv2.destroyAllWindows()

View File

@@ -234,7 +234,8 @@ class DatasetConfig:
self.flip_x: bool = kwargs.get('flip_x', False) self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False) self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', []) self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
self.poi: Union[str, None] = kwargs.get('poi', None) # if one is set and in json data, will be used as auto crop scale point of interes
# cache latents will store them in memory # cache latents will store them in memory
self.cache_latents: bool = kwargs.get('cache_latents', False) self.cache_latents: bool = kwargs.get('cache_latents', False)

View File

@@ -346,6 +346,7 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk self.is_caching_latents = dataset_config.cache_latents or dataset_config.cache_latents_to_disk
self.is_caching_latents_to_memory = dataset_config.cache_latents self.is_caching_latents_to_memory = dataset_config.cache_latents
self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk self.is_caching_latents_to_disk = dataset_config.cache_latents_to_disk
self.epoch_num = 0
self.sd = sd self.sd = sd
@@ -426,13 +427,20 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
self.setup_epoch() self.setup_epoch()
def setup_epoch(self): def setup_epoch(self):
# TODO: set this up to redo cropping and everything else if self.epoch_num == 0:
# do not call for now # initial setup
if self.dataset_config.buckets: # do not call for now
# setup buckets if self.dataset_config.buckets:
self.setup_buckets() # setup buckets
if self.is_caching_latents: self.setup_buckets()
self.cache_latents_all_latents() if self.is_caching_latents:
self.cache_latents_all_latents()
else:
if self.dataset_config.poi is not None:
# handle cropping to a specific point of interest
# setup buckets every epoch
self.setup_buckets(quiet=True)
self.epoch_num += 1
def __len__(self): def __len__(self):
if self.dataset_config.buckets: if self.dataset_config.buckets:
@@ -450,6 +458,9 @@ class AiToolkitDataset(LatentCachingMixin, BucketsMixin, CaptionMixin, Dataset):
# for buckets we collate ourselves for now # for buckets we collate ourselves for now
# todo allow a scheduler to dynamically make buckets # todo allow a scheduler to dynamically make buckets
# we collate ourselves # we collate ourselves
if len(self.batch_indices) - 1 < item:
# tried everything to solve this. No way to reset length when redoing things. Pick another index
item = random.randint(0, len(self.batch_indices) - 1)
idx_list = self.batch_indices[item] idx_list = self.batch_indices[item]
return [self._get_single_item(idx) for idx in idx_list] return [self._get_single_item(idx) for idx in idx_list]
else: else:
@@ -523,3 +534,27 @@ def get_dataloader_from_datasets(
collate_fn=dto_collation collate_fn=dto_collation
) )
return data_loader return data_loader
def _setup_epoch_and_reset_len(dataset) -> bool:
    # Run the dataset's per-epoch setup if it has one and clear its `len` attribute.
    # Returns True when `setup_epoch` existed and was called.
    if not hasattr(dataset, 'setup_epoch'):
        return False
    dataset.setup_epoch()
    dataset.len = None
    return True


def _setup_epoch_on_sub_datasets(container):
    # Apply the per-epoch setup to every entry of `container.datasets` that supports it.
    for sub_dataset in container.datasets:
        _setup_epoch_and_reset_len(sub_dataset)


def trigger_dataloader_setup_epoch(dataloader: DataLoader):
    """Trigger per-epoch setup on every dataset reachable from ``dataloader``.

    ``dataloader.dataset`` may be one of three shapes, handled in this order:

    * a ``list`` of datasets: each entry exposing ``datasets`` has its
      sub-datasets set up; any other entry is set up directly if it has a
      ``setup_epoch`` method.
    * a single dataset with ``setup_epoch``: it is set up directly.
    * a container exposing ``datasets``: its ``len`` is reset and each
      sub-dataset is set up.

    "Set up" means calling ``setup_epoch()`` and assigning ``len = None`` on
    that dataset. Objects without ``setup_epoch`` are silently skipped.

    :param dataloader: loader whose dataset(s) should start a new epoch.
    """
    # hacky but needed because of different types of datasets and dataloaders
    # NOTE(review): `len` is presumably a cached length that must be recomputed
    # after buckets are rebuilt -- confirm against the dataset classes.
    dataloader.len = None
    dataset = dataloader.dataset

    if isinstance(dataset, list):
        for entry in dataset:
            if hasattr(entry, 'datasets'):
                # container inside a list: only its children are touched
                _setup_epoch_on_sub_datasets(entry)
            else:
                _setup_epoch_and_reset_len(entry)
        return

    if _setup_epoch_and_reset_len(dataset):
        return

    if hasattr(dataset, 'datasets'):
        dataset.len = None
        _setup_epoch_on_sub_datasets(dataset)

View File

@@ -7,7 +7,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \ from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
ControlFileItemDTOMixin, ArgBreakMixin ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig from toolkit.config_modules import DatasetConfig
@@ -27,6 +27,7 @@ class FileItemDTO(
CaptionProcessingDTOMixin, CaptionProcessingDTOMixin,
ImageProcessingDTOMixin, ImageProcessingDTOMixin,
ControlFileItemDTOMixin, ControlFileItemDTOMixin,
PoiFileItemDTOMixin,
ArgBreakMixin, ArgBreakMixin,
): ):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -70,20 +71,25 @@ class FileItemDTO(
class DataLoaderBatchDTO: class DataLoaderBatchDTO:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None) try:
is_latents_cached = self.file_items[0].is_latent_cached self.file_items: List['FileItemDTO'] = kwargs.get('file_items', None)
self.tensor: Union[torch.Tensor, None] = None is_latents_cached = self.file_items[0].is_latent_cached
self.latents: Union[torch.Tensor, None] = None self.tensor: Union[torch.Tensor, None] = None
if not is_latents_cached: self.latents: Union[torch.Tensor, None] = None
# only return a tensor if latents are not cached if not is_latents_cached:
self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items]) # only return a tensor if latents are not cached
# if we have encoded latents, we concatenate them self.tensor: torch.Tensor = torch.cat([x.tensor.unsqueeze(0) for x in self.file_items])
self.latents: Union[torch.Tensor, None] = None # if we have encoded latents, we concatenate them
if is_latents_cached: self.latents: Union[torch.Tensor, None] = None
self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items]) if is_latents_cached:
self.control_tensor: Union[torch.Tensor, None] = None self.latents = torch.cat([x.get_latent().unsqueeze(0) for x in self.file_items])
if self.file_items[0].control_tensor is not None: self.control_tensor: Union[torch.Tensor, None] = None
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items]) if self.file_items[0].control_tensor is not None:
self.control_tensor = torch.cat([x.control_tensor.unsqueeze(0) for x in self.file_items])
except Exception as e:
print(e)
raise e
def get_is_reg_list(self): def get_is_reg_list(self):
return [x.is_reg for x in self.file_items] return [x.is_reg for x in self.file_items]

View File

@@ -29,11 +29,13 @@ if TYPE_CHECKING:
transforms_dict = { transforms_dict = {
'ColorJitter': transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.01), 'ColorJitter': transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.03),
'RandomEqualize': transforms.RandomEqualize(p=0.2), 'RandomEqualize': transforms.RandomEqualize(p=0.2),
} }
caption_ext_list = ['txt', 'json', 'caption'] caption_ext_list = ['txt', 'json', 'caption']
class CaptionMixin: class CaptionMixin:
def get_caption_item(self: 'AiToolkitDataset', index): def get_caption_item(self: 'AiToolkitDataset', index):
if not hasattr(self, 'caption_type'): if not hasattr(self, 'caption_type'):
@@ -106,66 +108,96 @@ class BucketsMixin:
self.batch_indices: List[List[int]] = [] self.batch_indices: List[List[int]] = []
def build_batch_indices(self: 'AiToolkitDataset'): def build_batch_indices(self: 'AiToolkitDataset'):
self.batch_indices = []
for key, bucket in self.buckets.items(): for key, bucket in self.buckets.items():
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size): for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx)) end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
batch = bucket.file_list_idx[start_idx:end_idx] batch = bucket.file_list_idx[start_idx:end_idx]
self.batch_indices.append(batch) self.batch_indices.append(batch)
def setup_buckets(self: 'AiToolkitDataset'): def setup_buckets(self: 'AiToolkitDataset', quiet=False):
if not hasattr(self, 'file_list'): if not hasattr(self, 'file_list'):
raise Exception(f'file_list not found on class instance {self.__class__.__name__}') raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
if not hasattr(self, 'dataset_config'): if not hasattr(self, 'dataset_config'):
raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}') raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
if self.epoch_num > 0 and self.dataset_config.poi is None:
# no need to rebuild buckets for now
# todo handle random cropping for buckets
return
self.buckets = {} # clear it
config: 'DatasetConfig' = self.dataset_config config: 'DatasetConfig' = self.dataset_config
resolution = config.resolution resolution = config.resolution
bucket_tolerance = config.bucket_tolerance bucket_tolerance = config.bucket_tolerance
file_list: List['FileItemDTO'] = self.file_list file_list: List['FileItemDTO'] = self.file_list
total_pixels = resolution * resolution
# for file_item in enumerate(file_list): # for file_item in enumerate(file_list):
for idx, file_item in enumerate(file_list): for idx, file_item in enumerate(file_list):
file_item: 'FileItemDTO' = file_item file_item: 'FileItemDTO' = file_item
width = file_item.crop_width width = int(file_item.width * file_item.dataset_config.scale)
height = file_item.crop_height height = int(file_item.height * file_item.dataset_config.scale)
bucket_resolution = get_bucket_for_image_size(width, height, resolution=resolution, if file_item.has_point_of_interest:
divisibility=bucket_tolerance) # let the poi module handle the bucketing
file_item.setup_poi_bucket()
# set the scaling height and with to match smallest size, and keep aspect ratio
if width > height:
file_item.scale_to_height = bucket_resolution["height"]
file_item.scale_to_width = int(width * (bucket_resolution["height"] / height))
else: else:
file_item.scale_to_width = bucket_resolution["width"] bucket_resolution = get_bucket_for_image_size(
file_item.scale_to_height = int(height * (bucket_resolution["width"] / width)) width, height,
resolution=resolution,
divisibility=bucket_tolerance
)
file_item.crop_height = bucket_resolution["height"] # Calculate scale factors for width and height
file_item.crop_width = bucket_resolution["width"] width_scale_factor = bucket_resolution["width"] / width
height_scale_factor = bucket_resolution["height"] / height
new_width = bucket_resolution["width"] # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
new_height = bucket_resolution["height"] max_scale_factor = max(width_scale_factor, height_scale_factor)
file_item.scale_to_width = int(width * max_scale_factor)
file_item.scale_to_height = int(height * max_scale_factor)
file_item.crop_height = bucket_resolution["height"]
file_item.crop_width = bucket_resolution["width"]
new_width = bucket_resolution["width"]
new_height = bucket_resolution["height"]
if self.dataset_config.random_crop:
# random crop
crop_x = random.randint(0, file_item.scale_to_width - new_width)
crop_y = random.randint(0, file_item.scale_to_height - new_height)
file_item.crop_x = crop_x
file_item.crop_y = crop_y
else:
# do central crop
file_item.crop_x = int((file_item.scale_to_width - new_width) / 2)
file_item.crop_y = int((file_item.scale_to_height - new_height) / 2)
if file_item.crop_y < 0 or file_item.crop_x < 0:
print('debug')
# check if bucket exists, if not, create it # check if bucket exists, if not, create it
bucket_key = f'{new_width}x{new_height}' bucket_key = f'{file_item.crop_width}x{file_item.crop_height}'
if bucket_key not in self.buckets: if bucket_key not in self.buckets:
self.buckets[bucket_key] = Bucket(new_width, new_height) self.buckets[bucket_key] = Bucket(file_item.crop_width, file_item.crop_height)
self.buckets[bucket_key].file_list_idx.append(idx) self.buckets[bucket_key].file_list_idx.append(idx)
# print the buckets # print the buckets
self.build_batch_indices() self.build_batch_indices()
name = f"{os.path.basename(self.dataset_path)} ({self.resolution})" if not quiet:
print(f'Bucket sizes for {self.dataset_path}:') print(f'Bucket sizes for {self.dataset_path}:')
for key, bucket in self.buckets.items(): for key, bucket in self.buckets.items():
print(f'{key}: {len(bucket.file_list_idx)} files') print(f'{key}: {len(bucket.file_list_idx)} files')
print(f'{len(self.buckets)} buckets made') print(f'{len(self.buckets)} buckets made')
# file buckets made
class CaptionProcessingDTOMixin: class CaptionProcessingDTOMixin:
def __init__(self: 'FileItemDTO', *args, **kwargs):
if hasattr(super(), '__init__'):
super().__init__(*args, **kwargs)
# todo allow for loading from sd-scripts style dict # todo allow for loading from sd-scripts style dict
def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]): def load_caption(self: 'FileItemDTO', caption_dict: Union[dict, None]):
@@ -281,10 +313,19 @@ class ImageProcessingDTOMixin:
img.transpose(Image.FLIP_TOP_BOTTOM) img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets: if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item # scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) # crop to x_crop, y_crop, x_crop + crop_width, y_crop + crop_height
if img.width < self.crop_x + self.crop_width or img.height < self.crop_y + self.crop_height:
print('size mismatch')
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
else: else:
# Downscale the source image first # Downscale the source image first
# TODO this is nto right # TODO this is nto right
@@ -371,7 +412,14 @@ class ControlFileItemDTOMixin:
if self.dataset_config.buckets: if self.dataset_config.buckets:
# scale and crop based on file item # scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC) img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((self.crop_height, self.crop_width))(img) # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else: else:
raise Exception("Control images not supported for non-bucket datasets") raise Exception("Control images not supported for non-bucket datasets")
@@ -381,6 +429,93 @@ class ControlFileItemDTOMixin:
self.control_tensor = None self.control_tensor = None
class PoiFileItemDTOMixin:
    """Mixin adding a point-of-interest (POI) bounding box to a file item.

    The POI allows dynamic cropping without cropping out the main subject:
    random crop edges are always drawn outside the POI box, so the box stays
    inside the randomly chosen crop region.

    The box is read from the image's sidecar ``.json`` caption file, under
    ``json_data['poi'][<poi name>]``, with keys ``x``, ``y``, ``width`` and
    ``height``.
    """

    def __init__(self: 'FileItemDTO', *args, **kwargs):
        """Load the POI box named by ``dataset_config.poi``, if one is configured.

        Reads the ``dataset_config`` and ``path`` keyword arguments.

        Raises:
            Exception: if a POI is configured while latent caching is enabled,
                while ``caption_ext`` is not ``'json'``, when the caption file
                does not exist, or when it does not contain the requested POI.
        """
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        # poi is a name of the box point of interest in the caption json file
        dataset_config = kwargs.get('dataset_config', None)
        path = kwargs.get('path', None)
        self.poi: Union[str, None] = dataset_config.poi
        self.has_point_of_interest = self.poi is not None
        self.poi_x: Union[int, None] = None
        self.poi_y: Union[int, None] = None
        self.poi_width: Union[int, None] = None
        self.poi_height: Union[int, None] = None

        if self.poi is None:
            return

        # make sure latent caching is off
        if dataset_config.cache_latents or dataset_config.cache_latents_to_disk:
            raise Exception(
                "Error: poi is not supported when caching latents. Please set cache_latents and cache_latents_to_disk to False in the dataset config"
            )
        # make sure we are loading through json
        if dataset_config.caption_ext != 'json':
            raise Exception(
                "Error: poi is only supported when using json captions. Please set caption_ext to json in the dataset config"
            )
        self.poi = self.poi.strip()

        # the caption file sits next to the image with a .json extension
        file_path_no_ext = os.path.splitext(path)[0]
        caption_path = file_path_no_ext + '.json'
        if not os.path.exists(caption_path):
            raise Exception(f"Error: caption file not found for poi: {caption_path}")
        with open(caption_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        if 'poi' not in json_data:
            raise Exception(f"Error: poi not found in caption file: {caption_path}")
        if self.poi not in json_data['poi']:
            # name the missing entry so this is distinguishable from a missing 'poi' section
            raise Exception(f"Error: poi '{self.poi}' not found in caption file: {caption_path}")

        # poi has, x, y, width, height
        poi = json_data['poi'][self.poi]
        self.poi_x = int(poi['x'])
        self.poi_y = int(poi['y'])
        self.poi_width = int(poi['width'])
        self.poi_height = int(poi['height'])

    def setup_poi_bucket(self: 'FileItemDTO'):
        """Pick a random crop containing the POI and derive bucket/scale/crop values.

        Sets ``scale_to_width``, ``scale_to_height``, ``crop_width``,
        ``crop_height``, ``crop_x`` and ``crop_y`` on the item. Reads
        ``self.width``, ``self.height``, the ``poi_*`` fields, and
        ``resolution``, ``bucket_tolerance`` and ``scale`` from
        ``self.dataset_config``.
        """
        # we are using poi, so we need to calculate the bucket based on the poi
        resolution = self.dataset_config.resolution
        bucket_tolerance = self.dataset_config.bucket_tolerance
        scale = self.dataset_config.scale
        initial_width = int(self.width * scale)
        initial_height = int(self.height * scale)
        poi_x = int(self.poi_x * scale)
        poi_y = int(self.poi_y * scale)
        poi_width = int(self.poi_width * scale)
        poi_height = int(self.poi_height * scale)

        # Clamp the box into the image. A box that starts at a negative offset or
        # runs past the image edge would otherwise hand random.randint an empty
        # range (ValueError). For a box already inside the image this is a no-op.
        poi_left = min(max(poi_x, 0), initial_width)
        poi_top = min(max(poi_y, 0), initial_height)
        poi_right = min(max(poi_x + poi_width, poi_left), initial_width)
        poi_bottom = min(max(poi_y + poi_height, poi_top), initial_height)

        # todo handle a poi that is smaller than resolution

        # determine new cropping: every edge is drawn outside the poi box
        crop_left = random.randint(0, poi_left)
        crop_right = random.randint(poi_right, initial_width)
        crop_top = random.randint(0, poi_top)
        crop_bottom = random.randint(poi_bottom, initial_height)

        new_width = crop_right - crop_left
        new_height = crop_bottom - crop_top

        bucket_resolution = get_bucket_for_image_size(
            new_width, new_height,
            resolution=resolution,
            divisibility=bucket_tolerance
        )

        width_scale_factor = bucket_resolution["width"] / new_width
        height_scale_factor = bucket_resolution["height"] / new_height
        # Use the maximum of the scale factors to ensure both dimensions are scaled above the bucket resolution
        max_scale_factor = max(width_scale_factor, height_scale_factor)

        # the whole image is scaled; the crop origin is scaled by the same factor
        self.scale_to_width = int(initial_width * max_scale_factor)
        self.scale_to_height = int(initial_height * max_scale_factor)
        self.crop_width = bucket_resolution['width']
        self.crop_height = bucket_resolution['height']
        self.crop_x = int(crop_left * max_scale_factor)
        self.crop_y = int(crop_top * max_scale_factor)
class ArgBreakMixin: class ArgBreakMixin:
# just stops super calls form hitting object # just stops super calls form hitting object
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):