# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat
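# A module docstring summarizing the script. The "zoedepth" model name and the
# "evaluate.py" filename below are illustrative guesses; any config name known
# to get_config works, and "nyu" is simply this script's default dataset.
"""Evaluate a monocular depth estimation model on one or more benchmark datasets.

Example invocation (sketch):

    python evaluate.py -m zoedepth -d nyu

Unrecognized command-line arguments are forwarded to get_config as config
overrides via parse_unknown.
"""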
import argparse
from pprint import pprint

import torch
from zoedepth.utils.easydict import EasyDict as edict
from tqdm import tqdm

from zoedepth.data.data_mono import DepthDataLoader
from zoedepth.models.builder import build_model
from zoedepth.utils.arg_utils import parse_unknown
from zoedepth.utils.config import change_dataset, get_config, ALL_EVAL_DATASETS, ALL_INDOOR, ALL_OUTDOOR
from zoedepth.utils.misc import (RunningAverageDict, colors, compute_metrics,
                                 count_parameters)

@torch.no_grad()
def infer(model, images, **kwargs):
    """Inference with horizontal-flip augmentation."""
    # images.shape = N, C, H, W
    def get_depth_from_prediction(pred):
        # Models may return a tensor, a list/tuple of intermediate outputs,
        # or a dict; normalize all of these to a single depth tensor.
        if isinstance(pred, torch.Tensor):
            pass  # already a depth tensor
        elif isinstance(pred, (list, tuple)):
            pred = pred[-1]
        elif isinstance(pred, dict):
            pred = pred['metric_depth'] if 'metric_depth' in pred else pred['out']
        else:
            raise NotImplementedError(f"Unknown output type {type(pred)}")
        return pred

    pred1 = model(images, **kwargs)
    pred1 = get_depth_from_prediction(pred1)

    # Run the model again on the horizontally flipped batch, flip the
    # prediction back, and average the two estimates.
    pred2 = model(torch.flip(images, [3]), **kwargs)
    pred2 = get_depth_from_prediction(pred2)
    pred2 = torch.flip(pred2, [3])

    mean_pred = 0.5 * (pred1 + pred2)

    return mean_pred
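
# Sketch of calling `infer` directly (hypothetical shapes and values; the
# `dataset` and `focal` kwargs simply mirror what the evaluation loop below
# passes through to the model):
#
#   images = torch.rand(1, 3, 480, 640).cuda()
#   depth = infer(model, images, dataset='nyu',
#                 focal=torch.Tensor([715.0873]).cuda())
#   # depth keeps the N, C, H, W layout, assuming the model preserves it.
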
@torch.no_grad()
def evaluate(model, test_loader, config, round_vals=True, round_precision=3):
    model.eval()
    metrics = RunningAverageDict()
    for i, sample in tqdm(enumerate(test_loader), total=len(test_loader)):
        if 'has_valid_depth' in sample:
            if not sample['has_valid_depth']:
                continue
        image, depth = sample['image'], sample['depth']
        image, depth = image.cuda(), depth.cuda()
        depth = depth.squeeze().unsqueeze(0).unsqueeze(0)
        focal = sample.get('focal', torch.Tensor(
            [715.0873]).cuda())  # This magic number (focal) is only used for evaluating the BTS model
        pred = infer(model, image, dataset=sample['dataset'][0], focal=focal)

        # Save image, depth and prediction for visualization
        if "save_images" in config and config.save_images:
            import os
            from PIL import Image
            import torchvision.transforms as transforms
            from zoedepth.utils.misc import colorize

            os.makedirs(config.save_images, exist_ok=True)
            d = colorize(depth.squeeze().cpu().numpy(), 0, 10)
            p = colorize(pred.squeeze().cpu().numpy(), 0, 10)
            im = transforms.ToPILImage()(image.squeeze().cpu())
            im.save(os.path.join(config.save_images, f"{i}_img.png"))
            Image.fromarray(d).save(os.path.join(config.save_images, f"{i}_depth.png"))
            Image.fromarray(p).save(os.path.join(config.save_images, f"{i}_pred.png"))

        metrics.update(compute_metrics(depth, pred, config=config))

    if round_vals:
        def r(m): return round(m, round_precision)
    else:
        def r(m): return m
    metrics = {k: r(v) for k, v in metrics.get_value().items()}
    return metrics
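
# Note: RunningAverageDict keeps a running average per metric key, so the dict
# returned above is the mean of compute_metrics over the whole test split.
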
def main(config):
    model = build_model(config)
    test_loader = DepthDataLoader(config, 'online_eval').data
    model = model.cuda()
    metrics = evaluate(model, test_loader, config)
    print(f"{colors.fg.green}")
    print(metrics)
    print(f"{colors.reset}")
    metrics['#params'] = f"{round(count_parameters(model, include_all=True)/1e6, 2)}M"
    return metrics


def eval_model(model_name, pretrained_resource, dataset='nyu', **kwargs):

    # Load the default pretrained resource defined in the config if not set
    overwrite = {**kwargs, "pretrained_resource": pretrained_resource} if pretrained_resource else kwargs
    config = get_config(model_name, "eval", dataset, **overwrite)
    # config = change_dataset(config, dataset)  # change the dataset
    pprint(config)
    print(f"Evaluating {model_name} on {dataset}...")
    metrics = main(config)
    return metrics
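
# Sketch of programmatic use, bypassing the CLI ("zoedepth" is an illustrative
# config name, not necessarily one shipped with this repo). An empty
# pretrained_resource falls back to the default declared in the model config:
#
#   metrics = eval_model("zoedepth", pretrained_resource="", dataset="nyu")
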
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str,
                        required=True, help="Name of the model to evaluate")
    parser.add_argument("-p", "--pretrained_resource", type=str,
                        required=False, default="",
                        help="Pretrained resource to use for fetching weights. If not set, the default "
                             "resource from the model config is used. Refer to "
                             "models.model_io.load_state_from_resource for more details.")
    parser.add_argument("-d", "--dataset", type=str, required=False,
                        default='nyu', help="Dataset to evaluate on")

    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    # "-d" accepts the shortcuts ALL_INDOOR, ALL_OUTDOOR and ALL, or a
    # comma-separated list of dataset names, in addition to a single name.
    if "ALL_INDOOR" in args.dataset:
        datasets = ALL_INDOOR
    elif "ALL_OUTDOOR" in args.dataset:
        datasets = ALL_OUTDOOR
    elif "ALL" in args.dataset:
        datasets = ALL_EVAL_DATASETS
    elif "," in args.dataset:
        datasets = args.dataset.split(",")
    else:
        datasets = [args.dataset]

    for dataset in datasets:
        eval_model(args.model, pretrained_resource=args.pretrained_resource,
                   dataset=dataset, **overwrite_kwargs)
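
# Example invocations (sketch; "zoedepth" and the "evaluate.py" filename are
# illustrative, while the dataset shortcuts come from the branch above):
#   python evaluate.py -m zoedepth -d nyu            # single dataset
#   python evaluate.py -m zoedepth -d nyu,kitti      # comma-separated list
#   python evaluate.py -m zoedepth -d ALL_INDOOR     # every indoor benchmark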