mirror of
https://github.com/huchenlei/Depth-Anything.git
synced 2026-05-01 04:41:13 +00:00
Initial commit
This commit is contained in:
573
metric_depth/zoedepth/data/data_mono.py
Normal file
573
metric_depth/zoedepth/data/data_mono.py
Normal file
@@ -0,0 +1,573 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2022 Intelligent Systems Lab Org
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# File author: Shariq Farooq Bhat
|
||||
|
||||
# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee
|
||||
|
||||
import itertools
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.data.distributed
|
||||
from zoedepth.utils.easydict import EasyDict as edict
|
||||
from PIL import Image, ImageOps
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torchvision import transforms
|
||||
|
||||
from zoedepth.utils.config import change_dataset
|
||||
|
||||
from .ddad import get_ddad_loader
|
||||
from .diml_indoor_test import get_diml_indoor_loader
|
||||
from .diml_outdoor_test import get_diml_outdoor_loader
|
||||
from .diode import get_diode_loader
|
||||
from .hypersim import get_hypersim_loader
|
||||
from .ibims import get_ibims_loader
|
||||
from .sun_rgbd_loader import get_sunrgbd_loader
|
||||
from .vkitti import get_vkitti_loader
|
||||
from .vkitti2 import get_vkitti2_loader
|
||||
|
||||
from .preprocess import CropParams, get_white_border, get_black_border
|
||||
|
||||
|
||||
def _is_pil_image(img):
    """Return True when *img* is a PIL ``Image`` instance."""
    return isinstance(img, Image.Image)
|
||||
|
||||
|
||||
def _is_numpy_image(img):
|
||||
return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
|
||||
|
||||
|
||||
def preprocessing_transforms(mode, **kwargs):
    """Build the per-sample transform pipeline (currently just ToTensor).

    *mode* and any extra keyword arguments (e.g. ``size``) are forwarded to
    :class:`ToTensor`.
    """
    pipeline = [ToTensor(mode=mode, **kwargs)]
    return transforms.Compose(pipeline)
|
||||
|
||||
|
||||
class DepthDataLoader(object):
    def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
        """
        Data loader for depth datasets

        Args:
            config (dict): Config dictionary. Refer to utils/config.py
            mode (str): "train" or "online_eval"
            device (str, optional): Device to load the data on. Defaults to 'cpu'.
            transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
        """

        self.config = config

        # ---- Evaluation-only datasets ----
        # Each of these is served by a dedicated loader (batch size 1, one
        # worker) and bypasses the generic DataLoadPreprocess pipeline below.
        if config.dataset == 'ibims':
            self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
            return

        if config.dataset == 'sunrgbd':
            self.data = get_sunrgbd_loader(
                data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_indoor':
            self.data = get_diml_indoor_loader(
                data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_outdoor':
            self.data = get_diml_outdoor_loader(
                data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
            return

        # Any dataset name containing "diode" (e.g. diode_indoor) reads its
        # data root from config["<dataset>_root"].
        if "diode" in config.dataset:
            self.data = get_diode_loader(
                config[config.dataset+"_root"], batch_size=1, num_workers=1)
            return

        if config.dataset == 'hypersim_test':
            self.data = get_hypersim_loader(
                config.hypersim_test_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti':
            self.data = get_vkitti_loader(
                config.vkitti_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti2':
            self.data = get_vkitti2_loader(
                config.vkitti2_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'ddad':
            self.data = get_ddad_loader(config.ddad_root, resize_shape=(
                352, 1216), batch_size=1, num_workers=1)
            return

        # Input resizing is opt-in: img_size is only honoured when
        # "do_input_resize" is set in the config.
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None

        if transform is None:
            transform = preprocessing_transforms(mode, size=img_size)

        if mode == 'train':

            # NOTE(review): shadows the imported torch Dataset name; harmless
            # locally but confusing.
            Dataset = DataLoadPreprocess
            self.training_samples = Dataset(
                config, mode, transform=transform, device=device)

            # Under distributed training each process sees its own shard.
            if config.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                    self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples,
                                   batch_size=config.batch_size,
                                   # shuffle only when no sampler is in charge
                                   shuffle=(self.train_sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   persistent_workers=True,
                                   # prefetch_factor=2,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            if config.distributed:  # redundant. here only for readability and to be more explicit
                # Give whole test set to all processes (and report evaluation only on one) regardless
                self.eval_sampler = None
            else:
                self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=kwargs.get("shuffle_test", False),
                                   num_workers=1,
                                   pin_memory=False,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            self.data = DataLoader(self.testing_samples,
                                   1, shuffle=False, num_workers=1)

        else:
            print(
                'mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
|
||||
|
||||
|
||||
def repetitive_roundrobin(*iterables):
    """
    cycles through iterables but sample wise
    first yield first sample from first iterable then first sample from second iterable and so on
    then second sample from first iterable then second sample from second iterable and so on

    If one iterable is shorter than the others, it is repeated (cycled) until
    all iterables are exhausted. Note that the round in which the LAST
    iterable exhausts is still completed, so one extra full round of
    (repeated) samples is emitted at the end:

    repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E A D F

    (the original docstring claimed the output stopped at "... C D E"; the
    trailing "A D F" round is real and is what makes the total length equal
    to len(iterables) * (max_len + 1), as assumed by
    RepetitiveRoundRobinDataLoader.__len__.)
    """
    # Repetitive roundrobin
    iterables_ = [iter(it) for it in iterables]
    exhausted = [False] * len(iterables)
    # The all-exhausted condition is only checked between rounds, so the
    # final round always runs to completion.
    while not all(exhausted):
        for i, it in enumerate(iterables_):
            try:
                yield next(it)
            except StopIteration:
                exhausted[i] = True
                # Restart the exhausted iterable as an endless cycle so it
                # keeps contributing samples until everything is exhausted.
                iterables_[i] = itertools.cycle(iterables[i])
                # First elements may get repeated if one iterable is shorter than the others
                yield next(iterables_[i])
|
||||
|
||||
|
||||
class RepetitiveRoundRobinDataLoader(object):
    """Interleaves several dataloaders batch-wise via repetitive_roundrobin.

    Shorter loaders are cycled so every loader contributes a batch each
    round until the longest one is exhausted.
    """

    def __init__(self, *dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        return repetitive_roundrobin(*self.dataloaders)

    def __len__(self):
        # The round in which the longest loader exhausts still completes,
        # repeating early samples of the others — hence the +1.
        longest = max(len(dl) for dl in self.dataloaders)
        return (longest + 1) * len(self.dataloaders)
|
||||
|
||||
|
||||
class MixedNYUKITTI(object):
    # Mixed-dataset loader: in 'train' mode interleaves NYU and KITTI
    # batches round-robin; in any other mode evaluates on NYU only.
    def __init__(self, config, mode, device='cpu', **kwargs):
        config = edict(config)
        # Each of the two underlying loaders gets half the worker budget.
        config.workers = config.workers // 2
        self.config = config
        # Derive per-dataset configs from fresh copies of the base config.
        nyu_conf = change_dataset(edict(config), 'nyu')
        kitti_conf = change_dataset(edict(config), 'kitti')

        # make nyu default for testing
        self.config = config = nyu_conf
        # Input resizing is opt-in via "do_input_resize", same as
        # DepthDataLoader.
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None
        if mode == 'train':
            nyu_loader = DepthDataLoader(
                nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            kitti_loader = DepthDataLoader(
                kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            # It has been changed to repetitive roundrobin
            self.data = RepetitiveRoundRobinDataLoader(
                nyu_loader, kitti_loader)
        else:
            # Evaluation/test: NYU only (see "make nyu default" above).
            self.data = DepthDataLoader(nyu_conf, mode, device=device).data
|
||||
|
||||
|
||||
def remove_leading_slash(s):
    """Strip a single leading '/' or '\\' from *s*.

    Paths read from the filenames file may start with a separator, which
    would make os.path.join discard the data root; this normalises them to
    relative paths.

    Args:
        s (str): A path fragment.

    Returns:
        str: *s* without its first character when that character is a path
        separator, otherwise *s* unchanged.
    """
    # s[:1] (rather than s[0]) avoids an IndexError on an empty string,
    # e.g. a blank line in the filenames file.
    if s[:1] in ('/', '\\'):
        return s[1:]
    return s
|
||||
|
||||
|
||||
class CachedReader:
    """Image reader that memoises opened images in a (possibly shared) dict."""

    def __init__(self, shared_dict=None):
        # A falsy shared_dict (None or empty) falls back to a private cache.
        self._cache = shared_dict if shared_dict else {}

    def open(self, fpath):
        """Return the image at *fpath*, opening and caching it on first use."""
        cached = self._cache.get(fpath, None)
        if cached is None:
            cached = self._cache[fpath] = Image.open(fpath)
        return cached
|
||||
|
||||
|
||||
class ImReader:
    """Stateless image reader: every call opens the file afresh."""

    def __init__(self):
        pass

    # @cache  (left disabled, as in the original)
    def open(self, fpath):
        """Open and return the image at *fpath* (no caching)."""
        return Image.open(fpath)
|
||||
|
||||
|
||||
class DataLoadPreprocess(Dataset):
    """Dataset of (image, depth) pairs listed in a whitespace-separated
    "filenames file", with the train-time augmentations and eval-time
    preprocessing of this pipeline.

    Each line: image path (col 0), depth path (col 1), focal length (col 2);
    KITTI lines may additionally carry right-camera image/depth paths in
    columns 3/4.
    """

    def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
        # config: edict-style config (see utils/config.py).
        # mode: 'train', 'online_eval' or 'test'.
        # transform: optional callable applied to the assembled sample dict.
        self.config = config
        # 'online_eval' reads its own split file; other modes share one.
        if mode == 'online_eval':
            with open(config.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(config.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)
        self.is_for_online_eval = is_for_online_eval
        # Optionally cache decoded images in a dict shared across workers.
        if config.use_shared_dict:
            self.reader = CachedReader(config.shared_dict)
        else:
            self.reader = ImReader()

    def postprocess(self, sample):
        # Subclass hook; identity by default.
        return sample

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        # Column 2 of every line is the focal length.
        focal = float(sample_path.split()[2])
        sample = {}

        if self.mode == 'train':
            # For KITTI with use_right, train on the right-camera pair
            # (columns 3/4) half of the time.
            if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[3]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
            else:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[0]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[1]))

            image = self.reader.open(image_path)
            depth_gt = self.reader.open(depth_path)
            w, h = image.size

            # KB crop: bottom-centre 352x1216 window of both modalities.
            if self.config.do_kb_crop:
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # Avoid blank boundaries due to pixel registration?
            # Train images have white border. Test images have black border.
            if self.config.dataset == 'nyu' and self.config.avoid_boundary:
                # We just crop and pad again with reflect padding to original size
                crop_params = get_white_border(np.array(image, dtype=np.uint8))
                image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
                depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))

                # Use reflect padding to fill the blank
                image = np.array(image)
                image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
                image = Image.fromarray(image)

                # Depth is zero-padded instead so padded pixels stay invalid.
                depth_gt = np.array(depth_gt)
                depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
                depth_gt = Image.fromarray(depth_gt)

            if self.config.do_random_rotate and (self.config.aug):
                # Uniform random angle in (-degree, +degree).
                random_angle = (random.random() - 0.5) * 2 * self.config.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(
                    depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            # Stored depth units -> metres (NYU PNGs are millimetres,
            # KITTI-style PNGs are 1/256 m).
            if self.config.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            if self.config.aug and (self.config.random_crop):
                image, depth_gt = self.random_crop(
                    image, depth_gt, self.config.input_height, self.config.input_width)

            if self.config.aug and self.config.random_translate:
                image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)

            image, depth_gt = self.train_preprocess(image, depth_gt)
            # Valid-depth mask with a leading channel dim: shape (1, H, W).
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample = {'image': image, 'depth': depth_gt, 'focal': focal,
                      'mask': mask, **sample}

        else:
            if self.mode == 'online_eval':
                data_path = self.config.data_path_eval
            else:
                data_path = self.config.data_path

            image_path = os.path.join(
                data_path, remove_leading_slash(sample_path.split()[0]))
            image = np.asarray(self.reader.open(image_path),
                               dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.config.gt_path_eval
                depth_path = os.path.join(
                    gt_path, remove_leading_slash(sample_path.split()[1]))
                has_valid_depth = False
                try:
                    depth_gt = self.reader.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    # Missing ground truth is tolerated: flag it, don't fail.
                    depth_gt = False
                    # print('Missing gt for {}'.format(image_path))

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.config.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

                    # Eval mask uses inclusive bounds (train uses strict ones).
                    mask = np.logical_and(
                        depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
                else:
                    mask = False

            # Same KB crop as in training, but on numpy arrays here.
            if self.config.do_kb_crop:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352,
                              left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin +
                                        352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
                          'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
                          'mask': mask}
            else:
                sample = {'image': image, 'focal': focal}

        # Recompute a strict-bounds mask for train and valid-eval samples.
        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample['mask'] = mask

        if self.transform:
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset
        # NOTE(review): assumes every line has a second column; a pure 'test'
        # split with image-only lines would raise IndexError here -- confirm.
        sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        # PIL rotate; callers pass NEAREST for depth so no interpolated
        # depth values are invented.
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        # Random spatially-aligned crop of image and depth to (height, width).
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        # With probability translate_prob, shift image and depth by the same
        # random (x, y) offset; warpAffine fills the border with zeros.
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        depth = depth.squeeze()[..., None]  # add channel dim back. Affine warp removes it

        return img, depth

    def train_preprocess(self, image, depth_gt):
        if self.config.aug:
            # Random horizontal flip (p=0.5), applied to both modalities.
            do_flip = random.random()
            if do_flip > 0.5:
                image = (image[:, ::-1, :]).copy()
                depth_gt = (depth_gt[:, ::-1, :]).copy()

            # Random gamma, brightness, color augmentation (image only, p=0.5)
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # gamma augmentation
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # brightness augmentation (wider range for NYU)
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # color augmentation: independent random gain per channel
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.filenames)
|
||||
|
||||
|
||||
class ToTensor(object):
    # Converts the image (and, depending on mode, depth) entries of a sample
    # dict to torch tensors, with optional normalization and resizing.
    def __init__(self, mode, do_normalize=False, size=None):
        # mode: 'train', 'test' or 'online_eval'; selects which keys of the
        #   sample are converted/returned in __call__.
        # do_normalize: apply ImageNet mean/std normalization when True.
        # size: target size for torchvision Resize; None disables resizing.
        self.mode = mode
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            self.resize = transforms.Resize(size=size)
        else:
            self.resize = nn.Identity()

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self.resize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal}
        else:
            # online_eval: depth is left as-is (numpy or False when missing).
            has_valid_depth = sample['has_valid_depth']
            # NOTE(review): the image was already resized above; this second
            # resize looks redundant (it is a no-op unless size was set) --
            # confirm it is intentional.
            image = self.resize(image)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
                    'image_path': sample['image_path'], 'depth_path': sample['depth_path']}

    def to_tensor(self, pic):
        # Convert a numpy HWC array or PIL image to a CHW torch tensor.
        # Unlike torchvision's ToTensor, values are NOT rescaled to [0, 1]:
        # numpy input is only transposed, and the byte path below converts to
        # float without dividing by 255.
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            # Raw byte buffer for the common 8-bit modes.
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        # Reshape flat buffer to HWC (PIL size is (W, H)), then HWC -> CHW.
        img = img.view(pic.size[1], pic.size[0], nchannel)

        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
|
||||
Reference in New Issue
Block a user