mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-03-13 14:39:50 +00:00
Built base interfaces for a DTO to handle batch infomation transports for the dataloader
This commit is contained in:
@@ -15,6 +15,8 @@ import albumentations as A
|
||||
from toolkit import image_utils
|
||||
from toolkit.config_modules import DatasetConfig
|
||||
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
|
||||
from toolkit.data_transfer_object.data_loader import FileItemDTO
|
||||
|
||||
|
||||
|
||||
class ImageDataset(Dataset, CaptionMixin):
|
||||
@@ -296,20 +298,6 @@ def print_once(msg):
|
||||
printed_messages.append(msg)
|
||||
|
||||
|
||||
class FileItem:
|
||||
def __init__(self, **kwargs):
|
||||
self.path = kwargs.get('path', None)
|
||||
self.width = kwargs.get('width', None)
|
||||
self.height = kwargs.get('height', None)
|
||||
# we scale first, then crop
|
||||
self.scale_to_width = kwargs.get('scale_to_width', self.width)
|
||||
self.scale_to_height = kwargs.get('scale_to_height', self.height)
|
||||
# crop values are from scaled size
|
||||
self.crop_x = kwargs.get('crop_x', 0)
|
||||
self.crop_y = kwargs.get('crop_y', 0)
|
||||
self.crop_width = kwargs.get('crop_width', self.scale_to_width)
|
||||
self.crop_height = kwargs.get('crop_height', self.scale_to_height)
|
||||
|
||||
|
||||
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
|
||||
@@ -325,7 +313,7 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
# we always random crop if random scale is enabled
|
||||
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
|
||||
self.resolution = dataset_config.resolution
|
||||
self.file_list: List['FileItem'] = []
|
||||
self.file_list: List['FileItemDTO'] = []
|
||||
|
||||
# get the file list
|
||||
file_list = [
|
||||
@@ -344,14 +332,16 @@ class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
|
||||
f'This process is faster for png, jpeg')
|
||||
img = Image.open(file)
|
||||
h, w = img.size
|
||||
# TODO allow smaller images
|
||||
if int(min(h, w) * self.scale) >= self.resolution:
|
||||
self.file_list.append(
|
||||
FileItem(
|
||||
FileItemDTO(
|
||||
path=file,
|
||||
width=w,
|
||||
height=h,
|
||||
scale_to_width=int(w * self.scale),
|
||||
scale_to_height=int(h * self.scale),
|
||||
dataset_config=dataset_config
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user