# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar

import torch
from torch.utils.data import Sampler

from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler


logger = logging.getLogger("dinov2")


class SamplerType(Enum):
    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    SHARDED_INFINITE_NEW = 4


def _make_bool_str(b: bool) -> str:
    return "yes" if b else "no"


def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
    def transform(sample):
        image, target = sample
        if image_transform is not None:
            image = image_transform(image)
        if target_transform is not None:
            target = target_transform(target)
        return image, target

    return transform
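

# Illustrative only (added comment, not part of the original module): a rough sketch of
# using the paired transform returned above; the torchvision transform and the label
# conversion are assumptions, not values taken from this codebase.
#
#   from torchvision import transforms
#   sample_transform = _make_sample_transform(
#       image_transform=transforms.ToTensor(),
#       target_transform=int,
#   )
#   image_tensor, target = sample_transform((pil_image, "3"))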


def _parse_dataset_str(dataset_str: str):
    tokens = dataset_str.split(":")

    name = tokens[0]
    kwargs = {}

    for token in tokens[1:]:
        key, value = token.split("=")
        assert key in ("root", "extra", "split")
        kwargs[key] = value

    if name == "ImageNet":
        class_ = ImageNet
        if "split" in kwargs:
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
    else:
        raise ValueError(f'Unsupported dataset "{name}"')

    return class_, kwargs
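

# Illustrative only (added comment): the dataset string is a colon-separated spec whose
# key=value pairs become constructor kwargs; the paths below are hypothetical. E.g.
#   _parse_dataset_str("ImageNet:split=TRAIN:root=/data/imagenet:extra=/data/imagenet-extra")
# returns (ImageNet, {"split": ImageNet.Split.TRAIN, "root": "/data/imagenet",
# "extra": "/data/imagenet-extra"}).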


def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Creates a dataset with the specified parameters.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    class_, kwargs = _parse_dataset_str(dataset_str)
    dataset = class_(transform=transform, target_transform=target_transform, **kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Aggregated datasets do not (yet) expose these attributes, so add them.
    if not hasattr(dataset, "transform"):
        setattr(dataset, "transform", transform)
    if not hasattr(dataset, "target_transform"):
        setattr(dataset, "target_transform", target_transform)

    return dataset
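

# Illustrative only (added comment): a minimal sketch of building a dataset from a spec
# string; the root/extra paths and the transform are assumptions.
#
#   from torchvision import transforms
#   dataset = make_dataset(
#       dataset_str="ImageNet:split=TRAIN:root=/data/imagenet:extra=/data/imagenet-extra",
#       transform=transforms.ToTensor(),
#   )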


def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )
    elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # TODO: Remove support for old shuffling
        use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
        )
    elif type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {size:,d}")
        return EpochSampler(
            size=size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )
    elif type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None
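

# Note (added comment): size > 0 is only meaningful for SamplerType.EPOCH, where it caps
# the number of samples drawn per epoch; the infinite and distributed samplers reject it.
# Likewise, advance is only honored by the infinite samplers. Any other type (including
# None) falls through to no sampler at all.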


T = TypeVar("T")


def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Creates a data loader with the specified parameters.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: Whether to keep worker Dataset instances alive after the dataset has been consumed once.
        collate_fn: Function that performs batch collation.
    """

    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )

    try:
        logger.info(f"# of batches: {len(data_loader):,d}")
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    return data_loader
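

# Illustrative only (added comment): a minimal sketch of a training-style loader built
# from the helpers above; the dataset spec, batch size and worker count are assumptions.
#
#   dataset = make_dataset(dataset_str="ImageNet:split=TRAIN:root=/data/imagenet:extra=/data/imagenet-extra")
#   data_loader = make_data_loader(
#       dataset=dataset,
#       batch_size=64,
#       num_workers=8,
#       shuffle=True,
#       sampler_type=SamplerType.SHARDED_INFINITE,
#   )
#   for images, targets in data_loader:
#       ...  # one training step per batch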