diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index dfa0a967..80043ae5 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -288,7 +288,10 @@ class BaseSDTrainProcess(BaseTrainProcess): imgs, prompts, dataset_config = batch # convert the 0 or 1 for is reg to a bool list - is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])]) + if isinstance(dataset_config, list): + is_reg_list = [x.get('is_reg', 0) for x in dataset_config] + else: + is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])]) if isinstance(is_reg_list, torch.Tensor): is_reg_list = is_reg_list.numpy().tolist() is_reg_list = [bool(x) for x in is_reg_list] diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py new file mode 100644 index 00000000..85a21ef7 --- /dev/null +++ b/testing/test_bucket_dataloader.py @@ -0,0 +1,37 @@ +from torch.utils.data import ConcatDataset, DataLoader +from tqdm import tqdm +# make sure we can import from the toolkit +import sys +import os + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets +from toolkit.config_modules import DatasetConfig +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('dataset_folder', type=str, default='input') + +args = parser.parse_args() + +dataset_folder = args.dataset_folder +resolution = 512 +bucket_tolerance = 64 +batch_size = 4 + +dataset_config = DatasetConfig( + folder_path=dataset_folder, + resolution=resolution, + caption_type='txt', + default_caption='default', + buckets=True, + bucket_tolerance=bucket_tolerance, +) + +dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size) + +# run through an epoch ang check sizes +for batch in dataloader: + print(list(batch[0].shape)) + +print('done') diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 1bff598d..6d90be9d 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -169,6 +169,7 @@ class DatasetConfig: self.resolution: int = kwargs.get('resolution', 512) self.scale: float = kwargs.get('scale', 1.0) self.buckets: bool = kwargs.get('buckets', False) + self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64) self.is_reg: bool = kwargs.get('is_reg', False) diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py index eabbe8d2..5393aa8a 100644 --- a/toolkit/data_loader.py +++ b/toolkit/data_loader.py @@ -4,6 +4,7 @@ from typing import List import cv2 import numpy as np +import torch from PIL import Image from PIL.ImageOps import exif_transpose from torchvision import transforms @@ -11,15 +12,9 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset from tqdm import tqdm import albumentations as A +from toolkit import image_utils from toolkit.config_modules import DatasetConfig -from toolkit.dataloader_mixins import CaptionMixin - -BUCKET_STEPS = 64 - -def get_bucket_sizes_for_resolution(resolution: int) -> List[int]: - # make sure resolution is divisible by 8 - if resolution % 8 != 0: - resolution = resolution - (resolution % 8) +from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin class ImageDataset(Dataset, CaptionMixin): @@ -291,32 +286,74 @@ class PairedImageDataset(Dataset): return img, prompt, (self.neg_weight, self.pos_weight) -class AiToolkitDataset(Dataset, CaptionMixin): - def __init__(self, dataset_config: 'DatasetConfig'): +printed_messages = [] + + +def print_once(msg): + global printed_messages + if msg not in printed_messages: + print(msg) + printed_messages.append(msg) + + +class FileItem: + def __init__(self, **kwargs): + self.path = kwargs.get('path', None) + self.width = kwargs.get('width', None) + self.height = kwargs.get('height', None) + # we scale first, then crop + self.scale_to_width = kwargs.get('scale_to_width', self.width) + self.scale_to_height = kwargs.get('scale_to_height', self.height) + # crop values are from scaled size + self.crop_x = kwargs.get('crop_x', 0) + self.crop_y = kwargs.get('crop_y', 0) + self.crop_width = kwargs.get('crop_width', self.scale_to_width) + self.crop_height = kwargs.get('crop_height', self.scale_to_height) + + +class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin): + file_list: List['FileItem'] = [] + + def __init__(self, dataset_config: 'DatasetConfig', batch_size=1): + super().__init__() self.dataset_config = dataset_config self.folder_path = dataset_config.folder_path self.caption_type = dataset_config.caption_type self.default_caption = dataset_config.default_caption self.random_scale = dataset_config.random_scale self.scale = dataset_config.scale + self.batch_size = batch_size # we always random crop if random scale is enabled self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop self.resolution = dataset_config.resolution # get the file list - self.file_list = [ + file_list = [ os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')) ] # this might take a while print(f" - Preprocessing image dimensions") - new_file_list = [] bad_count = 0 - for file in tqdm(self.file_list): - img = Image.open(file) - if int(min(img.size) * self.scale) >= self.resolution: - new_file_list.append(file) + for file in tqdm(file_list): + try: + w, h = image_utils.get_image_size(file) + except image_utils.UnknownImageFormat: + print_once(f'Warning: Some images in the dataset cannot be fast read. ' + \ + f'This process is faster for png, jpeg') + img = Image.open(file) + h, w = img.size + if int(min(h, w) * self.scale) >= self.resolution: + self.file_list.append( + FileItem( + path=file, + width=w, + height=h, + scale_to_width=int(w * self.scale), + scale_to_height=int(h * self.scale), + ) + ) else: bad_count += 1 @@ -324,35 +361,57 @@ class AiToolkitDataset(Dataset, CaptionMixin): print(f" - Found {bad_count} images that are too small") assert len(self.file_list) > 0, f"no images found in {self.folder_path}" + if self.dataset_config.buckets: + # setup buckets + self.setup_buckets() + self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1] ]) def __len__(self): + if self.dataset_config.buckets: + return len(self.batch_indices) return len(self.file_list) - def __getitem__(self, index): - img_path = self.file_list[index] - img = exif_transpose(Image.open(img_path)).convert('RGB') + def _get_single_item(self, index): + file_item = self.file_list[index] + # todo make sure this matches + img = exif_transpose(Image.open(file_item.path)).convert('RGB') + w, h = img.size + if w > h and file_item.scale_to_width < file_item.scale_to_height: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}") + elif h > w and file_item.scale_to_height < file_item.scale_to_width: + # throw error, they should match + raise ValueError( + f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}") # Downscale the source image first img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC) min_img_size = min(img.size) - if self.random_crop: - if self.random_scale and min_img_size > self.resolution: - if min_img_size < self.resolution: - print( - f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}") - scale_size = self.resolution - else: - scale_size = random.randint(self.resolution, int(min_img_size)) - img = img.resize((scale_size, scale_size), Image.BICUBIC) - img = transforms.RandomCrop(self.resolution)(img) + if self.dataset_config.buckets: + # todo allow scaling and cropping, will be hard to add + # scale and crop based on file item + img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC) + img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img) else: - img = transforms.CenterCrop(min_img_size)(img) - img = img.resize((self.resolution, self.resolution), Image.BICUBIC) + if self.random_crop: + if self.random_scale and min_img_size > self.resolution: + if min_img_size < self.resolution: + print( + f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}") + scale_size = self.resolution + else: + scale_size = random.randint(self.resolution, int(min_img_size)) + img = img.resize((scale_size, scale_size), Image.BICUBIC) + img = transforms.RandomCrop(self.resolution)(img) + else: + img = transforms.CenterCrop(min_img_size)(img) + img = img.resize((self.resolution, self.resolution), Image.BICUBIC) img = self.transform(img) @@ -367,6 +426,31 @@ class AiToolkitDataset(Dataset, CaptionMixin): else: return img, dataset_config_dict + def __getitem__(self, item): + if self.dataset_config.buckets: + # we collate ourselves + idx_list = self.batch_indices[item] + tensor_list = [] + prompt_list = [] + dataset_config_dict_list = [] + for idx in idx_list: + if self.caption_type is not None: + img, prompt, dataset_config_dict = self._get_single_item(idx) + prompt_list.append(prompt) + dataset_config_dict_list.append(dataset_config_dict) + else: + img, dataset_config_dict = self._get_single_item(idx) + dataset_config_dict_list.append(dataset_config_dict) + tensor_list.append(img.unsqueeze(0)) + + if self.caption_type is not None: + return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list + else: + return torch.cat(tensor_list, dim=0), dataset_config_dict_list + else: + # Dataloader is batching + return self._get_single_item(item) + def get_dataloader_from_datasets(dataset_options, batch_size=1): # TODO do bucketing @@ -374,22 +458,43 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1): return None datasets = [] + has_buckets = False for dataset_option in dataset_options: if isinstance(dataset_option, DatasetConfig): config = dataset_option else: config = DatasetConfig(**dataset_option) if config.type == 'image': - dataset = AiToolkitDataset(config) + dataset = AiToolkitDataset(config, batch_size=batch_size) datasets.append(dataset) + if config.buckets: + has_buckets = True else: raise ValueError(f"invalid dataset type: {config.type}") concatenated_dataset = ConcatDataset(datasets) - data_loader = DataLoader( - concatenated_dataset, - batch_size=batch_size, - shuffle=True, - num_workers=2 - ) + if has_buckets: + # make sure they all have buckets + for dataset in datasets: + assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none" + + def custom_collate_fn(batch): + # just return as is + return batch + + data_loader = DataLoader( + concatenated_dataset, + batch_size=None, # we batch in the dataloader + drop_last=False, + shuffle=True, + collate_fn=custom_collate_fn, # Use the custom collate function + num_workers=2 + ) + else: + data_loader = DataLoader( + concatenated_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=2 + ) return data_loader diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py index 913075fa..82b39e11 100644 --- a/toolkit/dataloader_mixins.py +++ b/toolkit/dataloader_mixins.py @@ -1,4 +1,5 @@ import os +from typing import TYPE_CHECKING, List, Dict class CaptionMixin: @@ -9,14 +10,16 @@ class CaptionMixin: raise Exception('file_list not found on class instance') img_path_or_tuple = self.file_list[index] if isinstance(img_path_or_tuple, tuple): + img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path # check if either has a prompt file - path_no_ext = os.path.splitext(img_path_or_tuple[0])[0] + path_no_ext = os.path.splitext(img_path)[0] prompt_path = path_no_ext + '.txt' if not os.path.exists(prompt_path): - path_no_ext = os.path.splitext(img_path_or_tuple[1])[0] + img_path = img_path_or_tuple[1] if isinstance(img_path_or_tuple[1], str) else img_path_or_tuple[1].path + path_no_ext = os.path.splitext(img_path)[0] prompt_path = path_no_ext + '.txt' else: - img_path = img_path_or_tuple + img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path # see if prompt file exists path_no_ext = os.path.splitext(img_path)[0] prompt_path = path_no_ext + '.txt' @@ -41,3 +44,97 @@ class CaptionMixin: if hasattr(self, 'default_caption'): prompt = self.default_caption return prompt + + +if TYPE_CHECKING: + from toolkit.config_modules import DatasetConfig + from toolkit.data_loader import FileItem + + +class Bucket: + def __init__(self, width: int, height: int): + self.width = width + self.height = height + self.file_list_idx: List[int] = [] + + +class BucketsMixin: + def __init__(self): + self.buckets: Dict[str, Bucket] = {} + self.batch_indices: List[List[int]] = [] + + def build_batch_indices(self): + for key, bucket in self.buckets.items(): + for start_idx in range(0, len(bucket.file_list_idx), self.batch_size): + end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx)) + batch = bucket.file_list_idx[start_idx:end_idx] + self.batch_indices.append(batch) + + def setup_buckets(self): + if not hasattr(self, 'file_list'): + raise Exception(f'file_list not found on class instance {self.__class__.__name__}') + if not hasattr(self, 'dataset_config'): + raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}') + + config: 'DatasetConfig' = self.dataset_config + resolution = config.resolution + bucket_tolerance = config.bucket_tolerance + file_list: List['FileItem'] = self.file_list + + # make sure out resolution is divisible by bucket_tolerance + if resolution % bucket_tolerance != 0: + # reduce it to the nearest divisible number + resolution = resolution - (resolution % bucket_tolerance) + + # for file_item in enumerate(file_list): + for idx, file_item in enumerate(file_list): + width = file_item.crop_width + height = file_item.crop_height + + # determine new size, smallest dimension should be equal to resolution + # the other dimension should be the same ratio it is now (bigger) + new_width = resolution + new_height = resolution + new_x = file_item.crop_x + new_y = file_item.crop_y + if width > height: + # scale width to match new resolution, + new_width = int(width * (resolution / height)) + # make sure new_width is divisible by bucket_tolerance + if new_width % bucket_tolerance != 0: + # reduce it to the nearest divisible number + reduction = new_width % bucket_tolerance + new_width = new_width - reduction + # adjust the new x position so we evenly crop + new_x = int(new_x + (reduction / 2)) + elif height > width: + # scale height to match new resolution + new_height = int(height * (resolution / width)) + # make sure new_height is divisible by bucket_tolerance + if new_height % bucket_tolerance != 0: + # reduce it to the nearest divisible number + reduction = new_height % bucket_tolerance + new_height = new_height - reduction + # adjust the new x position so we evenly crop + new_y = int(new_y + (reduction / 2)) + + # add info to file + file_item.crop_x = new_x + file_item.crop_y = new_y + file_item.crop_width = new_width + file_item.crop_height = new_height + + # check if bucket exists, if not, create it + bucket_key = f'{new_width}x{new_height}' + if bucket_key not in self.buckets: + self.buckets[bucket_key] = Bucket(new_width, new_height) + self.buckets[bucket_key].file_list_idx.append(idx) + + # print the buckets + self.build_batch_indices() + print(f'Bucket sizes for {self.__class__.__name__}:') + for key, bucket in self.buckets.items(): + print(f'{key}: {len(bucket.file_list_idx)} files') + print(f'{len(self.buckets)} buckets made') + + # file buckets made diff --git a/toolkit/image_utils.py b/toolkit/image_utils.py new file mode 100644 index 00000000..cb74ed4f --- /dev/null +++ b/toolkit/image_utils.py @@ -0,0 +1,422 @@ +# ref https://github.com/scardine/image_size/blob/master/get_image_size.py +import collections +import json +import os +import io +import struct + +FILE_UNKNOWN = "Sorry, don't know how to get size for this file." + + +class UnknownImageFormat(Exception): + pass + + +types = collections.OrderedDict() +BMP = types['BMP'] = 'BMP' +GIF = types['GIF'] = 'GIF' +ICO = types['ICO'] = 'ICO' +JPEG = types['JPEG'] = 'JPEG' +PNG = types['PNG'] = 'PNG' +TIFF = types['TIFF'] = 'TIFF' + +image_fields = ['path', 'type', 'file_size', 'width', 'height'] + + +class Image(collections.namedtuple('Image', image_fields)): + + def to_str_row(self): + return ("%d\t%d\t%d\t%s\t%s" % ( + self.width, + self.height, + self.file_size, + self.type, + self.path.replace('\t', '\\t'), + )) + + def to_str_row_verbose(self): + return ("%d\t%d\t%d\t%s\t%s\t##%s" % ( + self.width, + self.height, + self.file_size, + self.type, + self.path.replace('\t', '\\t'), + self)) + + def to_str_json(self, indent=None): + return json.dumps(self._asdict(), indent=indent) + + +def get_image_size(file_path): + """ + Return (width, height) for a given img file content - no external + dependencies except the os and struct builtin modules + """ + img = get_image_metadata(file_path) + return (img.width, img.height) + + +def get_image_size_from_bytesio(input, size): + """ + Return (width, height) for a given img file content - no external + dependencies except the os and struct builtin modules + + Args: + input (io.IOBase): io object support read & seek + size (int): size of buffer in byte + """ + img = get_image_metadata_from_bytesio(input, size) + return (img.width, img.height) + + +def get_image_metadata(file_path): + """ + Return an `Image` object for a given img file content - no external + dependencies except the os and struct builtin modules + + Args: + file_path (str): path to an image file + + Returns: + Image: (path, type, file_size, width, height) + """ + size = os.path.getsize(file_path) + + # be explicit with open arguments - we need binary mode + with io.open(file_path, "rb") as input: + return get_image_metadata_from_bytesio(input, size, file_path) + + +def get_image_metadata_from_bytesio(input, size, file_path=None): + """ + Return an `Image` object for a given img file content - no external + dependencies except the os and struct builtin modules + + Args: + input (io.IOBase): io object support read & seek + size (int): size of buffer in byte + file_path (str): path to an image file + + Returns: + Image: (path, type, file_size, width, height) + """ + height = -1 + width = -1 + data = input.read(26) + msg = " raised while trying to decode as JPEG." + + if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'): + # GIFs + imgtype = GIF + w, h = struct.unpack("= 24) and data.startswith(b'\211PNG\r\n\032\n') + and (data[12:16] == b'IHDR')): + # PNGs + imgtype = PNG + w, h = struct.unpack(">LL", data[16:24]) + width = int(w) + height = int(h) + elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'): + # older PNGs + imgtype = PNG + w, h = struct.unpack(">LL", data[8:16]) + width = int(w) + height = int(h) + elif (size >= 2) and data.startswith(b'\377\330'): + # JPEG + imgtype = JPEG + input.seek(0) + input.read(2) + b = input.read(1) + try: + while (b and ord(b) != 0xDA): + while (ord(b) != 0xFF): + b = input.read(1) + while (ord(b) == 0xFF): + b = input.read(1) + if (ord(b) >= 0xC0 and ord(b) <= 0xC3): + input.read(3) + h, w = struct.unpack(">HH", input.read(4)) + break + else: + input.read( + int(struct.unpack(">H", input.read(2))[0]) - 2) + b = input.read(1) + width = int(w) + height = int(h) + except struct.error: + raise UnknownImageFormat("StructError" + msg) + except ValueError: + raise UnknownImageFormat("ValueError" + msg) + except Exception as e: + raise UnknownImageFormat(e.__class__.__name__ + msg) + elif (size >= 26) and data.startswith(b'BM'): + # BMP + imgtype = 'BMP' + headersize = struct.unpack("= 40: + w, h = struct.unpack("= 8) and data[:4] in (b"II\052\000", b"MM\000\052"): + # Standard TIFF, big- or little-endian + # BigTIFF and other different but TIFF-like formats are not + # supported currently + imgtype = TIFF + byteOrder = data[:2] + boChar = ">" if byteOrder == "MM" else "<" + # maps TIFF type id to size (in bytes) + # and python format char for struct + tiffTypes = { + 1: (1, boChar + "B"), # BYTE + 2: (1, boChar + "c"), # ASCII + 3: (2, boChar + "H"), # SHORT + 4: (4, boChar + "L"), # LONG + 5: (8, boChar + "LL"), # RATIONAL + 6: (1, boChar + "b"), # SBYTE + 7: (1, boChar + "c"), # UNDEFINED + 8: (2, boChar + "h"), # SSHORT + 9: (4, boChar + "l"), # SLONG + 10: (8, boChar + "ll"), # SRATIONAL + 11: (4, boChar + "f"), # FLOAT + 12: (8, boChar + "d") # DOUBLE + } + ifdOffset = struct.unpack(boChar + "L", data[4:8])[0] + try: + countSize = 2 + input.seek(ifdOffset) + ec = input.read(countSize) + ifdEntryCount = struct.unpack(boChar + "H", ec)[0] + # 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4 + # bytes: value offset + ifdEntrySize = 12 + for i in range(ifdEntryCount): + entryOffset = ifdOffset + countSize + i * ifdEntrySize + input.seek(entryOffset) + tag = input.read(2) + tag = struct.unpack(boChar + "H", tag)[0] + if(tag == 256 or tag == 257): + # if type indicates that value fits into 4 bytes, value + # offset is not an offset but value itself + type = input.read(2) + type = struct.unpack(boChar + "H", type)[0] + if type not in tiffTypes: + raise UnknownImageFormat( + "Unkown TIFF field type:" + + str(type)) + typeSize = tiffTypes[type][0] + typeChar = tiffTypes[type][1] + input.seek(entryOffset + 8) + value = input.read(typeSize) + value = int(struct.unpack(typeChar, value)[0]) + if tag == 256: + width = value + else: + height = value + if width > -1 and height > -1: + break + except Exception as e: + raise UnknownImageFormat(str(e)) + elif size >= 2: + # see http://en.wikipedia.org/wiki/ICO_(file_format) + imgtype = 'ICO' + input.seek(0) + reserved = input.read(2) + if 0 != struct.unpack(" 1: + import warnings + warnings.warn("ICO File contains more than one image") + # http://msdn.microsoft.com/en-us/library/ms997538.aspx + w = input.read(1) + h = input.read(1) + width = ord(w) + height = ord(h) + else: + raise UnknownImageFormat(FILE_UNKNOWN) + + return Image(path=file_path, + type=imgtype, + file_size=size, + width=width, + height=height) + + +import unittest + + +class Test_get_image_size(unittest.TestCase): + data = [{ + 'path': 'lookmanodeps.png', + 'width': 251, + 'height': 208, + 'file_size': 22228, + 'type': 'PNG'}] + + def setUp(self): + pass + + def test_get_image_size_from_bytesio(self): + img = self.data[0] + p = img['path'] + with io.open(p, 'rb') as fp: + b = fp.read() + fp = io.BytesIO(b) + sz = len(b) + output = get_image_size_from_bytesio(fp, sz) + self.assertTrue(output) + self.assertEqual(output, + (img['width'], + img['height'])) + + def test_get_image_metadata_from_bytesio(self): + img = self.data[0] + p = img['path'] + with io.open(p, 'rb') as fp: + b = fp.read() + fp = io.BytesIO(b) + sz = len(b) + output = get_image_metadata_from_bytesio(fp, sz) + self.assertTrue(output) + for field in image_fields: + self.assertEqual(getattr(output, field), None if field == 'path' else img[field]) + + def test_get_image_metadata(self): + img = self.data[0] + output = get_image_metadata(img['path']) + self.assertTrue(output) + for field in image_fields: + self.assertEqual(getattr(output, field), img[field]) + + def test_get_image_metadata__ENOENT_OSError(self): + with self.assertRaises(OSError): + get_image_metadata('THIS_DOES_NOT_EXIST') + + def test_get_image_metadata__not_an_image_UnknownImageFormat(self): + with self.assertRaises(UnknownImageFormat): + get_image_metadata('README.rst') + + def test_get_image_size(self): + img = self.data[0] + output = get_image_size(img['path']) + self.assertTrue(output) + self.assertEqual(output, + (img['width'], + img['height'])) + + def tearDown(self): + pass + + +def main(argv=None): + """ + Print image metadata fields for the given file path. + + Keyword Arguments: + argv (list): commandline arguments (e.g. sys.argv[1:]) + Returns: + int: zero for OK + """ + import logging + import optparse + import sys + + prs = optparse.OptionParser( + usage="%prog [-v|--verbose] [--json|--json-indent] []", + description="Print metadata for the given image paths " + "(without image library bindings).") + + prs.add_option('--json', + dest='json', + action='store_true') + prs.add_option('--json-indent', + dest='json_indent', + action='store_true') + + prs.add_option('-v', '--verbose', + dest='verbose', + action='store_true',) + prs.add_option('-q', '--quiet', + dest='quiet', + action='store_true',) + prs.add_option('-t', '--test', + dest='run_tests', + action='store_true',) + + argv = list(argv) if argv is not None else sys.argv[1:] + (opts, args) = prs.parse_args(args=argv) + loglevel = logging.INFO + if opts.verbose: + loglevel = logging.DEBUG + elif opts.quiet: + loglevel = logging.ERROR + logging.basicConfig(level=loglevel) + log = logging.getLogger() + log.debug('argv: %r', argv) + log.debug('opts: %r', opts) + log.debug('args: %r', args) + + if opts.run_tests: + import sys + sys.argv = [sys.argv[0]] + args + import unittest + return unittest.main() + + output_func = Image.to_str_row + if opts.json_indent: + import functools + output_func = functools.partial(Image.to_str_json, indent=2) + elif opts.json: + output_func = Image.to_str_json + elif opts.verbose: + output_func = Image.to_str_row_verbose + + EX_OK = 0 + EX_NOT_OK = 2 + + if len(args) < 1: + prs.print_help() + print('') + prs.error("You must specify one or more paths to image files") + + errors = [] + for path_arg in args: + try: + img = get_image_metadata(path_arg) + print(output_func(img)) + except KeyboardInterrupt: + raise + except OSError as e: + log.error((path_arg, e)) + errors.append((path_arg, e)) + except Exception as e: + log.exception(e) + errors.append((path_arg, e)) + pass + if len(errors): + import pprint + print("ERRORS", file=sys.stderr) + print("======", file=sys.stderr) + print(pprint.pformat(errors, indent=2), file=sys.stderr) + return EX_NOT_OK + return EX_OK + + +if __name__ == "__main__": + import sys + sys.exit(main(argv=sys.argv[1:])) \ No newline at end of file