Added bucketing capabilities to dataloader. Finally have full planned capability. noice

This commit is contained in:
Jaret Burkett
2023-08-26 16:36:32 -06:00
parent 2cb27c3f57
commit 8105c05c12
6 changed files with 707 additions and 42 deletions

View File

@@ -288,7 +288,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
imgs, prompts, dataset_config = batch
# convert the 0 or 1 for is reg to a bool list
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(dataset_config, list):
is_reg_list = [x.get('is_reg', 0) for x in dataset_config]
else:
is_reg_list = dataset_config.get('is_reg', [0 for _ in range(imgs.shape[0])])
if isinstance(is_reg_list, torch.Tensor):
is_reg_list = is_reg_list.numpy().tolist()
is_reg_list = [bool(x) for x in is_reg_list]

View File

@@ -0,0 +1,37 @@
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
# make sure we can import from the toolkit
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets
from toolkit.config_modules import DatasetConfig
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, nargs='?', default='input')
args = parser.parse_args()
dataset_folder = args.dataset_folder
resolution = 512
bucket_tolerance = 64
batch_size = 4
dataset_config = DatasetConfig(
folder_path=dataset_folder,
resolution=resolution,
caption_type='txt',
default_caption='default',
buckets=True,
bucket_tolerance=bucket_tolerance,
)
dataloader = get_dataloader_from_datasets([dataset_config], batch_size=batch_size)
# run through an epoch and check the batch sizes
for batch in dataloader:
print(list(batch[0].shape))
print('done')
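To try it, run this script (its path is not shown in the diff) with the dataset folder as the positional argument, e.g. python <script>.py /path/to/images. With buckets enabled, each printed shape should look like [batch, 3, H, W], where H and W are multiples of bucket_tolerance, the smaller of the two should match the configured resolution, and the last batch of a bucket may hold fewer than batch_size images.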

View File

@@ -169,6 +169,7 @@ class DatasetConfig:
self.resolution: int = kwargs.get('resolution', 512)
self.scale: float = kwargs.get('scale', 1.0)
self.buckets: bool = kwargs.get('buckets', False)
self.bucket_tolerance: int = kwargs.get('bucket_tolerance', 64)
self.is_reg: bool = kwargs.get('is_reg', False)
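The two new keys slot into the existing config flow: get_dataloader_from_datasets (changed below) accepts either DatasetConfig objects or plain dicts, so a bucketed dataset can be described like this (a minimal sketch; the folder path is a placeholder and the remaining keys keep their defaults):

dataset_option = {
    'type': 'image',                    # routed to AiToolkitDataset
    'folder_path': '/path/to/images',   # placeholder
    'resolution': 512,
    'buckets': True,                    # enable aspect-ratio bucketing
    'bucket_tolerance': 64,             # bucket dims rounded down to multiples of this
}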

View File

@@ -4,6 +4,7 @@ from typing import List
import cv2
import numpy as np
import torch
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
@@ -11,15 +12,9 @@ from torch.utils.data import Dataset, DataLoader, ConcatDataset
from tqdm import tqdm
import albumentations as A
from toolkit import image_utils
from toolkit.config_modules import DatasetConfig
from toolkit.dataloader_mixins import CaptionMixin
BUCKET_STEPS = 64
def get_bucket_sizes_for_resolution(resolution: int) -> List[int]:
# make sure resolution is divisible by 8
if resolution % 8 != 0:
resolution = resolution - (resolution % 8)
from toolkit.dataloader_mixins import CaptionMixin, BucketsMixin
class ImageDataset(Dataset, CaptionMixin):
@@ -291,32 +286,74 @@ class PairedImageDataset(Dataset):
return img, prompt, (self.neg_weight, self.pos_weight)
class AiToolkitDataset(Dataset, CaptionMixin):
def __init__(self, dataset_config: 'DatasetConfig'):
printed_messages = []
def print_once(msg):
global printed_messages
if msg not in printed_messages:
print(msg)
printed_messages.append(msg)
class FileItem:
def __init__(self, **kwargs):
self.path = kwargs.get('path', None)
self.width = kwargs.get('width', None)
self.height = kwargs.get('height', None)
# we scale first, then crop
self.scale_to_width = kwargs.get('scale_to_width', self.width)
self.scale_to_height = kwargs.get('scale_to_height', self.height)
# crop values are from scaled size
self.crop_x = kwargs.get('crop_x', 0)
self.crop_y = kwargs.get('crop_y', 0)
self.crop_width = kwargs.get('crop_width', self.scale_to_width)
self.crop_height = kwargs.get('crop_height', self.scale_to_height)
class AiToolkitDataset(Dataset, CaptionMixin, BucketsMixin):
file_list: List['FileItem'] = []
def __init__(self, dataset_config: 'DatasetConfig', batch_size=1):
super().__init__()
self.dataset_config = dataset_config
self.folder_path = dataset_config.folder_path
self.caption_type = dataset_config.caption_type
self.default_caption = dataset_config.default_caption
self.random_scale = dataset_config.random_scale
self.scale = dataset_config.scale
self.batch_size = batch_size
# we always random crop if random scale is enabled
self.random_crop = self.random_scale if self.random_scale else dataset_config.random_crop
self.resolution = dataset_config.resolution
# get the file list
self.file_list = [
file_list = [
os.path.join(self.folder_path, file) for file in os.listdir(self.folder_path) if
file.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
]
# this might take a while
print(f" - Preprocessing image dimensions")
new_file_list = []
bad_count = 0
for file in tqdm(self.file_list):
img = Image.open(file)
if int(min(img.size) * self.scale) >= self.resolution:
new_file_list.append(file)
for file in tqdm(file_list):
try:
w, h = image_utils.get_image_size(file)
except image_utils.UnknownImageFormat:
print_once('Warning: Some images in the dataset cannot be fast-read; '
'this preprocessing step is faster for png and jpeg files')
img = Image.open(file)
w, h = img.size  # PIL .size is (width, height)
if int(min(h, w) * self.scale) >= self.resolution:
self.file_list.append(
FileItem(
path=file,
width=w,
height=h,
scale_to_width=int(w * self.scale),
scale_to_height=int(h * self.scale),
)
)
else:
bad_count += 1
@@ -324,35 +361,57 @@ class AiToolkitDataset(Dataset, CaptionMixin):
print(f" - Found {bad_count} images that are too small")
assert len(self.file_list) > 0, f"no images found in {self.folder_path}"
if self.dataset_config.buckets:
# setup buckets
self.setup_buckets()
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]), # normalize to [-1, 1]
])
def __len__(self):
if self.dataset_config.buckets:
return len(self.batch_indices)
return len(self.file_list)
def __getitem__(self, index):
img_path = self.file_list[index]
img = exif_transpose(Image.open(img_path)).convert('RGB')
def _get_single_item(self, index):
file_item = self.file_list[index]
# todo make sure this matches
img = exif_transpose(Image.open(file_item.path)).convert('RGB')
w, h = img.size
if w > h and file_item.scale_to_width < file_item.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
elif h > w and file_item.scale_to_height < file_item.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={file_item.scale_to_width}, file_item.scale_to_height={file_item.scale_to_height}, file_item.path={file_item.path}")
# Downscale the source image first
img = img.resize((int(img.size[0] * self.scale), int(img.size[1] * self.scale)), Image.BICUBIC)
min_img_size = min(img.size)
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={img_path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
if self.dataset_config.buckets:
# todo allow scaling and cropping, will be hard to add
# scale and crop based on file item
img = img.resize((file_item.scale_to_width, file_item.scale_to_height), Image.BICUBIC)
img = transforms.CenterCrop((file_item.crop_height, file_item.crop_width))(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
if self.random_crop:
if self.random_scale and min_img_size > self.resolution:
if min_img_size < self.resolution:
print(
f"Unexpected values: min_img_size={min_img_size}, self.resolution={self.resolution}, image file={file_item.path}")
scale_size = self.resolution
else:
scale_size = random.randint(self.resolution, int(min_img_size))
img = img.resize((scale_size, scale_size), Image.BICUBIC)
img = transforms.RandomCrop(self.resolution)(img)
else:
img = transforms.CenterCrop(min_img_size)(img)
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
img = self.transform(img)
@@ -367,6 +426,31 @@ class AiToolkitDataset(Dataset, CaptionMixin):
else:
return img, dataset_config_dict
def __getitem__(self, item):
if self.dataset_config.buckets:
# we collate ourselves
idx_list = self.batch_indices[item]
tensor_list = []
prompt_list = []
dataset_config_dict_list = []
for idx in idx_list:
if self.caption_type is not None:
img, prompt, dataset_config_dict = self._get_single_item(idx)
prompt_list.append(prompt)
dataset_config_dict_list.append(dataset_config_dict)
else:
img, dataset_config_dict = self._get_single_item(idx)
dataset_config_dict_list.append(dataset_config_dict)
tensor_list.append(img.unsqueeze(0))
if self.caption_type is not None:
return torch.cat(tensor_list, dim=0), prompt_list, dataset_config_dict_list
else:
return torch.cat(tensor_list, dim=0), dataset_config_dict_list
else:
# Dataloader is batching
return self._get_single_item(item)
def get_dataloader_from_datasets(dataset_options, batch_size=1):
# TODO do bucketing
@@ -374,22 +458,43 @@ def get_dataloader_from_datasets(dataset_options, batch_size=1):
return None
datasets = []
has_buckets = False
for dataset_option in dataset_options:
if isinstance(dataset_option, DatasetConfig):
config = dataset_option
else:
config = DatasetConfig(**dataset_option)
if config.type == 'image':
dataset = AiToolkitDataset(config)
dataset = AiToolkitDataset(config, batch_size=batch_size)
datasets.append(dataset)
if config.buckets:
has_buckets = True
else:
raise ValueError(f"invalid dataset type: {config.type}")
concatenated_dataset = ConcatDataset(datasets)
data_loader = DataLoader(
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
if has_buckets:
# make sure they all have buckets
for dataset in datasets:
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
def custom_collate_fn(batch):
# just return as is
return batch
data_loader = DataLoader(
concatenated_dataset,
batch_size=None, # batching is handled by the dataset, one bucket batch per item
drop_last=False,
shuffle=True,
collate_fn=custom_collate_fn, # Use the custom collate function
num_workers=2
)
else:
data_loader = DataLoader(
concatenated_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=2
)
return data_loader
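A note on the bucketed path, since it changes what a batch looks like (a sketch, assuming caption_type is set so the 3-tuple form is returned): each dataset item is already a full per-bucket batch, batch_size=None turns off automatic batching, and the pass-through collate hands that batch to the loop unchanged. The images in a batch therefore share one bucket size, and dataset_config arrives as a list of per-image dicts, which is what the BaseSDTrainProcess change at the top of this commit accounts for.

for batch in data_loader:
    imgs, prompts, dataset_config = batch   # imgs: [B, 3, H, W], all from one bucket
    # with buckets enabled, dataset_config is a list of dicts (one per image)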

View File

@@ -1,4 +1,5 @@
import os
from typing import TYPE_CHECKING, List, Dict
class CaptionMixin:
@@ -9,14 +10,16 @@ class CaptionMixin:
raise Exception('file_list not found on class instance')
img_path_or_tuple = self.file_list[index]
if isinstance(img_path_or_tuple, tuple):
img_path = img_path_or_tuple[0] if isinstance(img_path_or_tuple[0], str) else img_path_or_tuple[0].path
# check if either has a prompt file
path_no_ext = os.path.splitext(img_path_or_tuple[0])[0]
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
if not os.path.exists(prompt_path):
path_no_ext = os.path.splitext(img_path_or_tuple[1])[0]
img_path = img_path_or_tuple[1] if isinstance(img_path_or_tuple[1], str) else img_path_or_tuple[1].path
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
else:
img_path = img_path_or_tuple
img_path = img_path_or_tuple if isinstance(img_path_or_tuple, str) else img_path_or_tuple.path
# see if prompt file exists
path_no_ext = os.path.splitext(img_path)[0]
prompt_path = path_no_ext + '.txt'
@@ -41,3 +44,97 @@ class CaptionMixin:
if hasattr(self, 'default_caption'):
prompt = self.default_caption
return prompt
if TYPE_CHECKING:
from toolkit.config_modules import DatasetConfig
from toolkit.data_loader import FileItem
class Bucket:
def __init__(self, width: int, height: int):
self.width = width
self.height = height
self.file_list_idx: List[int] = []
class BucketsMixin:
def __init__(self):
self.buckets: Dict[str, Bucket] = {}
self.batch_indices: List[List[int]] = []
def build_batch_indices(self):
for key, bucket in self.buckets.items():
for start_idx in range(0, len(bucket.file_list_idx), self.batch_size):
end_idx = min(start_idx + self.batch_size, len(bucket.file_list_idx))
batch = bucket.file_list_idx[start_idx:end_idx]
self.batch_indices.append(batch)
def setup_buckets(self):
if not hasattr(self, 'file_list'):
raise Exception(f'file_list not found on class instance {self.__class__.__name__}')
if not hasattr(self, 'dataset_config'):
raise Exception(f'dataset_config not found on class instance {self.__class__.__name__}')
config: 'DatasetConfig' = self.dataset_config
resolution = config.resolution
bucket_tolerance = config.bucket_tolerance
file_list: List['FileItem'] = self.file_list
# make sure our resolution is divisible by bucket_tolerance
if resolution % bucket_tolerance != 0:
# reduce it to the nearest divisible number
resolution = resolution - (resolution % bucket_tolerance)
# for file_item in enumerate(file_list):
for idx, file_item in enumerate(file_list):
width = file_item.crop_width
height = file_item.crop_height
# determine new size, smallest dimension should be equal to resolution
# the other dimension should be the same ratio it is now (bigger)
new_width = resolution
new_height = resolution
new_x = file_item.crop_x
new_y = file_item.crop_y
if width > height:
# scale width to match new resolution,
new_width = int(width * (resolution / height))
# make sure new_width is divisible by bucket_tolerance
if new_width % bucket_tolerance != 0:
# reduce it to the nearest divisible number
reduction = new_width % bucket_tolerance
new_width = new_width - reduction
# adjust the new x position so we evenly crop
new_x = int(new_x + (reduction / 2))
elif height > width:
# scale height to match new resolution
new_height = int(height * (resolution / width))
# make sure new_height is divisible by bucket_tolerance
if new_height % bucket_tolerance != 0:
# reduce it to the nearest divisible number
reduction = new_height % bucket_tolerance
new_height = new_height - reduction
# adjust the new y position so we evenly crop
new_y = int(new_y + (reduction / 2))
# add info to file
file_item.crop_x = new_x
file_item.crop_y = new_y
file_item.crop_width = new_width
file_item.crop_height = new_height
# check if bucket exists, if not, create it
bucket_key = f'{new_width}x{new_height}'
if bucket_key not in self.buckets:
self.buckets[bucket_key] = Bucket(new_width, new_height)
self.buckets[bucket_key].file_list_idx.append(idx)
# print the buckets
self.build_batch_indices()
print(f'Bucket sizes for {self.__class__.__name__}:')
for key, bucket in self.buckets.items():
print(f'{key}: {len(bucket.file_list_idx)} files')
print(f'{len(self.buckets)} buckets made')
# done building buckets
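As a worked example of the arithmetic above (hypothetical numbers, not taken from the repo): an 850x800 crop with resolution=512 and bucket_tolerance=64 lands in the '512x512' bucket, with crop_x nudged so the trim stays centered:

width, height = 850, 800            # file_item.crop_width, file_item.crop_height
resolution, tolerance = 512, 64
new_width, new_height = resolution, resolution
if width > height:
    new_width = int(width * (resolution / height))   # 544
    reduction = new_width % tolerance                # 32
    new_width -= reduction                           # 512
    # crop_x is shifted by int(reduction / 2) == 16 to keep the crop centered
bucket_key = f'{new_width}x{new_height}'             # '512x512'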

toolkit/image_utils.py (new file, 422 lines)
View File

@@ -0,0 +1,422 @@
# ref https://github.com/scardine/image_size/blob/master/get_image_size.py
import collections
import json
import os
import io
import struct
FILE_UNKNOWN = "Sorry, don't know how to get size for this file."
class UnknownImageFormat(Exception):
pass
types = collections.OrderedDict()
BMP = types['BMP'] = 'BMP'
GIF = types['GIF'] = 'GIF'
ICO = types['ICO'] = 'ICO'
JPEG = types['JPEG'] = 'JPEG'
PNG = types['PNG'] = 'PNG'
TIFF = types['TIFF'] = 'TIFF'
image_fields = ['path', 'type', 'file_size', 'width', 'height']
class Image(collections.namedtuple('Image', image_fields)):
def to_str_row(self):
return ("%d\t%d\t%d\t%s\t%s" % (
self.width,
self.height,
self.file_size,
self.type,
self.path.replace('\t', '\\t'),
))
def to_str_row_verbose(self):
return ("%d\t%d\t%d\t%s\t%s\t##%s" % (
self.width,
self.height,
self.file_size,
self.type,
self.path.replace('\t', '\\t'),
self))
def to_str_json(self, indent=None):
return json.dumps(self._asdict(), indent=indent)
def get_image_size(file_path):
"""
Return (width, height) for a given img file content - no external
dependencies except the os and struct builtin modules
"""
img = get_image_metadata(file_path)
return (img.width, img.height)
def get_image_size_from_bytesio(input, size):
"""
Return (width, height) for a given img file content - no external
dependencies except the os and struct builtin modules
Args:
input (io.IOBase): io object support read & seek
size (int): size of buffer in bytes
"""
img = get_image_metadata_from_bytesio(input, size)
return (img.width, img.height)
def get_image_metadata(file_path):
"""
Return an `Image` object for a given img file content - no external
dependencies except the os and struct builtin modules
Args:
file_path (str): path to an image file
Returns:
Image: (path, type, file_size, width, height)
"""
size = os.path.getsize(file_path)
# be explicit with open arguments - we need binary mode
with io.open(file_path, "rb") as input:
return get_image_metadata_from_bytesio(input, size, file_path)
def get_image_metadata_from_bytesio(input, size, file_path=None):
"""
Return an `Image` object for a given img file content - no external
dependencies except the os and struct builtin modules
Args:
input (io.IOBase): io object support read & seek
size (int): size of buffer in bytes
file_path (str): path to an image file
Returns:
Image: (path, type, file_size, width, height)
"""
height = -1
width = -1
data = input.read(26)
msg = " raised while trying to decode as JPEG."
if (size >= 10) and data[:6] in (b'GIF87a', b'GIF89a'):
# GIFs
imgtype = GIF
w, h = struct.unpack("<HH", data[6:10])
width = int(w)
height = int(h)
elif ((size >= 24) and data.startswith(b'\211PNG\r\n\032\n')
and (data[12:16] == b'IHDR')):
# PNGs
imgtype = PNG
w, h = struct.unpack(">LL", data[16:24])
width = int(w)
height = int(h)
elif (size >= 16) and data.startswith(b'\211PNG\r\n\032\n'):
# older PNGs
imgtype = PNG
w, h = struct.unpack(">LL", data[8:16])
width = int(w)
height = int(h)
elif (size >= 2) and data.startswith(b'\377\330'):
# JPEG
imgtype = JPEG
input.seek(0)
input.read(2)
b = input.read(1)
try:
while (b and ord(b) != 0xDA):
while (ord(b) != 0xFF):
b = input.read(1)
while (ord(b) == 0xFF):
b = input.read(1)
if (ord(b) >= 0xC0 and ord(b) <= 0xC3):
input.read(3)
h, w = struct.unpack(">HH", input.read(4))
break
else:
input.read(
int(struct.unpack(">H", input.read(2))[0]) - 2)
b = input.read(1)
width = int(w)
height = int(h)
except struct.error:
raise UnknownImageFormat("StructError" + msg)
except ValueError:
raise UnknownImageFormat("ValueError" + msg)
except Exception as e:
raise UnknownImageFormat(e.__class__.__name__ + msg)
elif (size >= 26) and data.startswith(b'BM'):
# BMP
imgtype = 'BMP'
headersize = struct.unpack("<I", data[14:18])[0]
if headersize == 12:
w, h = struct.unpack("<HH", data[18:22])
width = int(w)
height = int(h)
elif headersize >= 40:
w, h = struct.unpack("<ii", data[18:26])
width = int(w)
# as h is negative when stored upside down
height = abs(int(h))
else:
raise UnknownImageFormat(
"Unkown DIB header size:" +
str(headersize))
elif (size >= 8) and data[:4] in (b"II\052\000", b"MM\000\052"):
# Standard TIFF, big- or little-endian
# BigTIFF and other different but TIFF-like formats are not
# supported currently
imgtype = TIFF
byteOrder = data[:2]
boChar = ">" if byteOrder == "MM" else "<"
# maps TIFF type id to size (in bytes)
# and python format char for struct
tiffTypes = {
1: (1, boChar + "B"), # BYTE
2: (1, boChar + "c"), # ASCII
3: (2, boChar + "H"), # SHORT
4: (4, boChar + "L"), # LONG
5: (8, boChar + "LL"), # RATIONAL
6: (1, boChar + "b"), # SBYTE
7: (1, boChar + "c"), # UNDEFINED
8: (2, boChar + "h"), # SSHORT
9: (4, boChar + "l"), # SLONG
10: (8, boChar + "ll"), # SRATIONAL
11: (4, boChar + "f"), # FLOAT
12: (8, boChar + "d") # DOUBLE
}
ifdOffset = struct.unpack(boChar + "L", data[4:8])[0]
try:
countSize = 2
input.seek(ifdOffset)
ec = input.read(countSize)
ifdEntryCount = struct.unpack(boChar + "H", ec)[0]
# 2 bytes: TagId + 2 bytes: type + 4 bytes: count of values + 4
# bytes: value offset
ifdEntrySize = 12
for i in range(ifdEntryCount):
entryOffset = ifdOffset + countSize + i * ifdEntrySize
input.seek(entryOffset)
tag = input.read(2)
tag = struct.unpack(boChar + "H", tag)[0]
if(tag == 256 or tag == 257):
# if type indicates that value fits into 4 bytes, value
# offset is not an offset but value itself
type = input.read(2)
type = struct.unpack(boChar + "H", type)[0]
if type not in tiffTypes:
raise UnknownImageFormat(
"Unkown TIFF field type:" +
str(type))
typeSize = tiffTypes[type][0]
typeChar = tiffTypes[type][1]
input.seek(entryOffset + 8)
value = input.read(typeSize)
value = int(struct.unpack(typeChar, value)[0])
if tag == 256:
width = value
else:
height = value
if width > -1 and height > -1:
break
except Exception as e:
raise UnknownImageFormat(str(e))
elif size >= 2:
# see http://en.wikipedia.org/wiki/ICO_(file_format)
imgtype = 'ICO'
input.seek(0)
reserved = input.read(2)
if 0 != struct.unpack("<H", reserved)[0]:
raise UnknownImageFormat(FILE_UNKNOWN)
format = input.read(2)
assert 1 == struct.unpack("<H", format)[0]
num = input.read(2)
num = struct.unpack("<H", num)[0]
if num > 1:
import warnings
warnings.warn("ICO File contains more than one image")
# http://msdn.microsoft.com/en-us/library/ms997538.aspx
w = input.read(1)
h = input.read(1)
width = ord(w)
height = ord(h)
else:
raise UnknownImageFormat(FILE_UNKNOWN)
return Image(path=file_path,
type=imgtype,
file_size=size,
width=width,
height=height)
import unittest
class Test_get_image_size(unittest.TestCase):
data = [{
'path': 'lookmanodeps.png',
'width': 251,
'height': 208,
'file_size': 22228,
'type': 'PNG'}]
def setUp(self):
pass
def test_get_image_size_from_bytesio(self):
img = self.data[0]
p = img['path']
with io.open(p, 'rb') as fp:
b = fp.read()
fp = io.BytesIO(b)
sz = len(b)
output = get_image_size_from_bytesio(fp, sz)
self.assertTrue(output)
self.assertEqual(output,
(img['width'],
img['height']))
def test_get_image_metadata_from_bytesio(self):
img = self.data[0]
p = img['path']
with io.open(p, 'rb') as fp:
b = fp.read()
fp = io.BytesIO(b)
sz = len(b)
output = get_image_metadata_from_bytesio(fp, sz)
self.assertTrue(output)
for field in image_fields:
self.assertEqual(getattr(output, field), None if field == 'path' else img[field])
def test_get_image_metadata(self):
img = self.data[0]
output = get_image_metadata(img['path'])
self.assertTrue(output)
for field in image_fields:
self.assertEqual(getattr(output, field), img[field])
def test_get_image_metadata__ENOENT_OSError(self):
with self.assertRaises(OSError):
get_image_metadata('THIS_DOES_NOT_EXIST')
def test_get_image_metadata__not_an_image_UnknownImageFormat(self):
with self.assertRaises(UnknownImageFormat):
get_image_metadata('README.rst')
def test_get_image_size(self):
img = self.data[0]
output = get_image_size(img['path'])
self.assertTrue(output)
self.assertEqual(output,
(img['width'],
img['height']))
def tearDown(self):
pass
def main(argv=None):
"""
Print image metadata fields for the given file path.
Keyword Arguments:
argv (list): commandline arguments (e.g. sys.argv[1:])
Returns:
int: zero for OK
"""
import logging
import optparse
import sys
prs = optparse.OptionParser(
usage="%prog [-v|--verbose] [--json|--json-indent] <path0> [<pathN>]",
description="Print metadata for the given image paths "
"(without image library bindings).")
prs.add_option('--json',
dest='json',
action='store_true')
prs.add_option('--json-indent',
dest='json_indent',
action='store_true')
prs.add_option('-v', '--verbose',
dest='verbose',
action='store_true',)
prs.add_option('-q', '--quiet',
dest='quiet',
action='store_true',)
prs.add_option('-t', '--test',
dest='run_tests',
action='store_true',)
argv = list(argv) if argv is not None else sys.argv[1:]
(opts, args) = prs.parse_args(args=argv)
loglevel = logging.INFO
if opts.verbose:
loglevel = logging.DEBUG
elif opts.quiet:
loglevel = logging.ERROR
logging.basicConfig(level=loglevel)
log = logging.getLogger()
log.debug('argv: %r', argv)
log.debug('opts: %r', opts)
log.debug('args: %r', args)
if opts.run_tests:
import sys
sys.argv = [sys.argv[0]] + args
import unittest
return unittest.main()
output_func = Image.to_str_row
if opts.json_indent:
import functools
output_func = functools.partial(Image.to_str_json, indent=2)
elif opts.json:
output_func = Image.to_str_json
elif opts.verbose:
output_func = Image.to_str_row_verbose
EX_OK = 0
EX_NOT_OK = 2
if len(args) < 1:
prs.print_help()
print('')
prs.error("You must specify one or more paths to image files")
errors = []
for path_arg in args:
try:
img = get_image_metadata(path_arg)
print(output_func(img))
except KeyboardInterrupt:
raise
except OSError as e:
log.error((path_arg, e))
errors.append((path_arg, e))
except Exception as e:
log.exception(e)
errors.append((path_arg, e))
pass
if len(errors):
import pprint
print("ERRORS", file=sys.stderr)
print("======", file=sys.stderr)
print(pprint.pformat(errors, indent=2), file=sys.stderr)
return EX_NOT_OK
return EX_OK
if __name__ == "__main__":
import sys
sys.exit(main(argv=sys.argv[1:]))
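For reference, a minimal sketch of how the rest of this commit consumes this module (it mirrors the try/except added in toolkit/data_loader.py; the file path is a placeholder):

from PIL import Image
from toolkit import image_utils

path = '/path/to/image.webp'   # placeholder
try:
    # fast header-only read for BMP/GIF/ICO/JPEG/PNG/TIFF, no full image decode
    w, h = image_utils.get_image_size(path)
except image_utils.UnknownImageFormat:
    # formats this parser does not handle (e.g. webp) fall back to a full PIL open
    w, h = Image.open(path).size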