# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat

# This file is partly inspired from BTS (https://github.com/cleinc/bts/blob/master/pytorch/bts_dataloader.py); author: Jin Han Lee

import itertools
import os
import random

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.utils.data.distributed
from zoedepth.utils.easydict import EasyDict as edict
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from zoedepth.utils.config import change_dataset

from .ddad import get_ddad_loader
from .diml_indoor_test import get_diml_indoor_loader
from .diml_outdoor_test import get_diml_outdoor_loader
from .diode import get_diode_loader
from .hypersim import get_hypersim_loader
from .ibims import get_ibims_loader
from .sun_rgbd_loader import get_sunrgbd_loader
from .vkitti import get_vkitti_loader
from .vkitti2 import get_vkitti2_loader

from .preprocess import CropParams, get_white_border, get_black_border

def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def preprocessing_transforms(mode, **kwargs):
    return transforms.Compose([
        ToTensor(mode=mode, **kwargs)
    ])

class DepthDataLoader(object):
    def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
        """
        Data loader for depth datasets

        Args:
            config (dict): Config dictionary. Refer to utils/config.py
            mode (str): "train" or "online_eval"
            device (str, optional): Device to load the data on. Defaults to 'cpu'.
            transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
        """

        self.config = config

        if config.dataset == 'ibims':
            self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
            return

        if config.dataset == 'sunrgbd':
            self.data = get_sunrgbd_loader(
                data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_indoor':
            self.data = get_diml_indoor_loader(
                data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_outdoor':
            self.data = get_diml_outdoor_loader(
                data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
            return

        if "diode" in config.dataset:
            self.data = get_diode_loader(
                config[config.dataset + "_root"], batch_size=1, num_workers=1)
            return

        if config.dataset == 'hypersim_test':
            self.data = get_hypersim_loader(
                config.hypersim_test_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti':
            self.data = get_vkitti_loader(
                config.vkitti_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti2':
            self.data = get_vkitti2_loader(
                config.vkitti2_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'ddad':
            self.data = get_ddad_loader(config.ddad_root, resize_shape=(
                352, 1216), batch_size=1, num_workers=1)
            return

        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None

        if transform is None:
            transform = preprocessing_transforms(mode, size=img_size)

        if mode == 'train':

            Dataset = DataLoadPreprocess
            self.training_samples = Dataset(
                config, mode, transform=transform, device=device)

            if config.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                    self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples,
                                   batch_size=config.batch_size,
                                   shuffle=(self.train_sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   persistent_workers=True,
                                   # prefetch_factor=2,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            if config.distributed:  # redundant. here only for readability and to be more explicit
                # Give the whole test set to all processes (and report evaluation only on one) regardless
                self.eval_sampler = None
            else:
                self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=kwargs.get("shuffle_test", False),
                                   num_workers=1,
                                   pin_memory=False,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            self.data = DataLoader(self.testing_samples,
                                   1, shuffle=False, num_workers=1)

        else:
            raise ValueError(
                "mode should be one of 'train', 'test', 'online_eval'. Got {}".format(mode))

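# A minimal usage sketch (illustrative only; `get_config` is assumed to be the
# helper from zoedepth.utils.config and the exact arguments may differ):
#
#     from zoedepth.utils.config import get_config
#     config = get_config("zoedepth", "train", "nyu")
#     train_loader = DepthDataLoader(config, mode="train").data
#     for batch in train_loader:
#         image, depth = batch["image"], batch["depth"]  # (B, 3, H, W), (B, 1, H, W)
#         break
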
def repetitive_roundrobin(*iterables):
    """
    Cycles through iterables, but sample-wise:
    first yield the first sample from the first iterable, then the first sample
    from the second iterable, and so on; then the second sample from the first
    iterable, then the second sample from the second iterable, and so on.

    If one iterable is shorter than the others, it is repeated (cycled) until
    all iterables are exhausted:
    repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E A D F
    """
    # Repetitive roundrobin
    iterables_ = [iter(it) for it in iterables]
    exhausted = [False] * len(iterables)
    while not all(exhausted):
        for i, it in enumerate(iterables_):
            try:
                yield next(it)
            except StopIteration:
                exhausted[i] = True
                iterables_[i] = itertools.cycle(iterables[i])
                # First elements may get repeated if one iterable is shorter than the others
                yield next(iterables_[i])

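# Quick sanity check of the behaviour documented above (hypothetical snippet):
#
#     >>> ' '.join(repetitive_roundrobin('ABC', 'D', 'EF'))
#     'A D E B D F C D E A D F'
#
# The generator only stops once every input has raised StopIteration at least
# once, so the final round draws one extra (cycled) sample per iterable.
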
class RepetitiveRoundRobinDataLoader(object):
    def __init__(self, *dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        return repetitive_roundrobin(*self.dataloaders)

    def __len__(self):
        # The first samples get repeated in the final round, hence the plus one
        return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)

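# Worked example of the __len__ formula above (illustrative numbers): with two
# dataloaders of lengths 100 and 60, one pass yields 2 * (100 + 1) = 202
# batches, since the shorter loader is cycled until the longer one finishes and
# one extra round is drawn once the longer one is exhausted.
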
class MixedNYUKITTI(object):
    def __init__(self, config, mode, device='cpu', **kwargs):
        config = edict(config)
        config.workers = config.workers // 2  # split workers between the two loaders
        self.config = config
        nyu_conf = change_dataset(edict(config), 'nyu')
        kitti_conf = change_dataset(edict(config), 'kitti')

        # make nyu default for testing
        self.config = config = nyu_conf
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get(
            "do_input_resize", False) else None
        if mode == 'train':
            nyu_loader = DepthDataLoader(
                nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            kitti_loader = DepthDataLoader(
                kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            # It has been changed to repetitive roundrobin
            self.data = RepetitiveRoundRobinDataLoader(
                nyu_loader, kitti_loader)
        else:
            self.data = DepthDataLoader(nyu_conf, mode, device=device).data

def remove_leading_slash(s):
    if s[0] == '/' or s[0] == '\\':
        return s[1:]
    return s


class CachedReader:
    def __init__(self, shared_dict=None):
        if shared_dict:
            self._cache = shared_dict
        else:
            self._cache = {}

    def open(self, fpath):
        im = self._cache.get(fpath, None)
        if im is None:
            im = self._cache[fpath] = Image.open(fpath)
        return im


class ImReader:
    def __init__(self):
        pass

    # @cache
    def open(self, fpath):
        return Image.open(fpath)

class DataLoadPreprocess(Dataset):
    def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
        self.config = config
        if mode == 'online_eval':
            with open(config.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(config.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)
        self.is_for_online_eval = is_for_online_eval
        if config.use_shared_dict:
            self.reader = CachedReader(config.shared_dict)
        else:
            self.reader = ImReader()

    def postprocess(self, sample):
        return sample

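    # Note: each line of the filenames file is expected to provide, in order
    # (inferred from the indexing in __getitem__ below; the exact layout is an
    # assumption that depends on the dataset split files):
    #   <image_path> <depth_path> <focal> [<right_image_path> <right_depth_path>]
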
    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        focal = float(sample_path.split()[2])
        sample = {}

        if self.mode == 'train':
            if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[3]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
            else:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[0]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[1]))

            image = self.reader.open(image_path)
            depth_gt = self.reader.open(depth_path)
            w, h = image.size

            if self.config.do_kb_crop:
                # KITTI benchmark (kb) crop to 352 x 1216, centered horizontally
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))

            # Avoid blank boundaries due to pixel registration?
            # Train images have white border. Test images have black border.
            if self.config.dataset == 'nyu' and self.config.avoid_boundary:
                # print("Avoiding Blank Boundaries!")
                # We just crop and pad again with reflect padding to original size
                # original_size = image.size
                crop_params = get_white_border(np.array(image, dtype=np.uint8))
                image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
                depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))

                # Use reflect padding to fill the blank
                image = np.array(image)
                image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
                image = Image.fromarray(image)

                depth_gt = np.array(depth_gt)
                depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
                depth_gt = Image.fromarray(depth_gt)

            if self.config.do_random_rotate and (self.config.aug):
                random_angle = (random.random() - 0.5) * 2 * self.config.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(
                    depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            if self.config.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0  # NYU stores depth in millimeters
            else:
                depth_gt = depth_gt / 256.0  # KITTI stores depth as uint16 with a 1/256 m scale

            if self.config.aug and (self.config.random_crop):
                image, depth_gt = self.random_crop(
                    image, depth_gt, self.config.input_height, self.config.input_width)

            if self.config.aug and self.config.random_translate:
                # print("Random Translation!")
                image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)

            image, depth_gt = self.train_preprocess(image, depth_gt)
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample = {'image': image, 'depth': depth_gt, 'focal': focal,
                      'mask': mask, **sample}

        else:
            if self.mode == 'online_eval':
                data_path = self.config.data_path_eval
            else:
                data_path = self.config.data_path

            image_path = os.path.join(
                data_path, remove_leading_slash(sample_path.split()[0]))
            image = np.asarray(self.reader.open(image_path),
                               dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.config.gt_path_eval
                depth_path = os.path.join(
                    gt_path, remove_leading_slash(sample_path.split()[1]))
                has_valid_depth = False
                try:
                    depth_gt = self.reader.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    depth_gt = False  # sentinel: no ground truth for this sample
                    # print('Missing gt for {}'.format(image_path))

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.config.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

                    mask = np.logical_and(
                        depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
                else:
                    mask = False

            if self.config.do_kb_crop:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352,
                              left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin +
                                        352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
                          'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
                          'mask': mask}
            else:
                sample = {'image': image, 'focal': focal}

        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample['mask'] = mask

        if self.transform:
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset
        sample = {**sample, 'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1]}

        return sample

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
        # print(img.shape, depth.shape)
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        depth = depth.squeeze()[..., None]  # add channel dim back. Affine warp removes it
        # print("after", img.shape, depth.shape)
        return img, depth

    def train_preprocess(self, image, depth_gt):
        if self.config.aug:
            # Random flipping
            do_flip = random.random()
            if do_flip > 0.5:
                image = (image[:, ::-1, :]).copy()
                depth_gt = (depth_gt[:, ::-1, :]).copy()

            # Random gamma, brightness, color augmentation
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # gamma augmentation
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # brightness augmentation
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # color augmentation
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.filenames)

class ToTensor(object):
    def __init__(self, mode, do_normalize=False, size=None):
        self.mode = mode
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            self.resize = transforms.Resize(size=size)
        else:
            self.resize = nn.Identity()

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self.resize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal}
        else:
            has_valid_depth = sample['has_valid_depth']
            return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
                    'image_path': sample['image_path'], 'depth_path': sample['depth_path']}

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            # HWC -> CHW
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # HWC -> CHW
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
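
# A minimal ToTensor usage sketch (illustrative; the dummy arrays and the NYU
# focal value 518.8579 are assumptions for demonstration only):
#
#     sample = {'image': np.random.rand(480, 640, 3).astype(np.float32),
#               'depth': np.random.rand(480, 640, 1).astype(np.float32),
#               'focal': 518.8579}
#     out = ToTensor(mode='train')(sample)
#     assert out['image'].shape == (3, 480, 640)
#     assert out['depth'].shape == (1, 480, 640)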