diff --git a/README.md b/README.md
index c070a36..a6bfa09 100644
--- a/README.md
+++ b/README.md
@@ -59,7 +59,7 @@ We highlight the **best** and *second best* results in **bold** and *italic* res
 
 ## Pre-trained models
 
-We provide three models of varying scales for robust relatve depth estimation:
+We provide three models of varying scales for robust relative depth estimation:
 
 - Depth-Anything-ViT-Small (24.8M)
 
@@ -67,14 +67,18 @@ We provide three models of varying scales for robust relatve depth estimation:
 
 - Depth-Anything-ViT-Large (335.3M)
 
-Download our pre-trained models [here](https://huggingface.co/spaces/LiheYoung/Depth-Anything/tree/main/checkpoints), and put them under the ``checkpoints`` directory.
+You can easily load our pre-trained models by:
+```python
+from depth_anything.dpt import DepthAnything
+
+encoder = 'vits' # can also be 'vitb' or 'vitl'
+depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{:}14'.format(encoder))
+```
 
 ## Usage
 
 ### Installation
 
-The setup is very simple. Just make ensure ``torch``, ``torchvision``, and ``cv2`` are supported in your environment.
-
 ```bash
 git clone https://github.com/LiheYoung/Depth-Anything
 cd Depth-Anything
@@ -84,13 +88,13 @@ pip install -r requirements.txt
 ### Running
 
 ```bash
-python run.py --encoder <vits | vitb | vitl> --load-from <checkpoint-path> --img-path <img-directory | single-img | txt-file> --outdir <outdir> --localhub
+python run.py --encoder <vits | vitb | vitl> --img-path <img-directory | single-img | txt-file> --outdir <outdir>
 ```
 
 For the ``img-path``, you can either 1) point it to an image directory storing all interested images, 2) point it to a single image, or 3) point it to a text file storing all image paths.
 
 For example:
 ```bash
-python run.py --encoder vitl --load-from checkpoints/depth_anything_vitl14.pth --img-path demo_images --outdir depth_visualization --localhub
+python run.py --encoder vitl --img-path demo_images --outdir depth_visualization
 ```
 
@@ -112,14 +116,14 @@ If you want to use Depth Anything in your own project, you can simply follow [``
 Code snippet (note the difference between our data pre-processing and that of MiDaS)
 
 ```python
-from depth_anything.dpt import DPT_DINOv2
+from depth_anything.dpt import DepthAnything
 from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
 
 import cv2
 import torch
 
-depth_anything = DPT_DINOv2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], localhub=True)
-depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_vitl14.pth'))
+encoder = 'vits' # can also be 'vitb' or 'vitl'
+depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{:}14'.format(encoder))
 
 transform = Compose([
     Resize(
diff --git a/app.py b/app.py
index 11c708f..8e730f9 100644
--- a/app.py
+++ b/app.py
@@ -9,7 +9,7 @@ from torchvision.transforms import Compose
 import tempfile
 from gradio_imageslider import ImageSlider
 
-from depth_anything.dpt import DPT_DINOv2
+from depth_anything.dpt import DepthAnything
 from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
 
 css = """
@@ -24,8 +24,7 @@ css = """
 }
 """
 DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
-model = DPT_DINOv2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024]).to(DEVICE).eval()
-model.load_state_dict(torch.load('checkpoints/depth_anything_vitl14.pth'))
+model = DepthAnything.from_pretrained('LiheYoung/depth_anything_vitl14').to(DEVICE).eval()
 
 title = "# Depth Anything"
 description = """Official demo for **Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data**.
@@ -49,7 +48,6 @@ transform = Compose([
 def predict_depth(model, image):
     return model(image)
 
-
 with gr.Blocks(css=css) as demo:
     gr.Markdown(title)
     gr.Markdown(description)
@@ -93,4 +91,4 @@ with gr.Blocks(css=css) as demo:
 
 
 if __name__ == '__main__':
-    demo.queue().launch()
+    demo.queue().launch()
\ No newline at end of file
diff --git a/depth_anything/dpt.py b/depth_anything/dpt.py
index aeb1d09..56b9545 100644
--- a/depth_anything/dpt.py
+++ b/depth_anything/dpt.py
@@ -1,8 +1,10 @@
+import argparse
 import torch
 import torch.nn as nn
-
-from .blocks import FeatureFusionBlock, _make_scratch
 import torch.nn.functional as F
+from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
+
+from depth_anything.blocks import FeatureFusionBlock, _make_scratch
 
 
 def _make_fusion_block(features, use_bn, size = None):
@@ -164,7 +166,22 @@ class DPT_DINOv2(nn.Module):
         return depth.squeeze(1)
 
 
+class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
+    def __init__(self, config):
+        super().__init__(**config)
+
+
 
 if __name__ == '__main__':
-    depth_anything = DPT_DINOv2()
-    depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_dinov2_vitl14.pth'))
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--encoder",
+        default="vits",
+        type=str,
+        choices=["vits", "vitb", "vitl"],
+    )
+    args = parser.parse_args()
+
+    model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
+
+    print(model)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 0d750ee..4044895 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,4 +2,5 @@ gradio_imageslider
 gradio==4.14.0
 torch
 torchvision
-opencv-python
\ No newline at end of file
+opencv-python
+huggingface_hub
\ No newline at end of file
diff --git a/run.py b/run.py
index 67def6f..a0d07ae 100644
--- a/run.py
+++ b/run.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from torchvision.transforms import Compose
 from tqdm import tqdm
 
-from depth_anything.dpt import DPT_DINOv2
+from depth_anything.dpt import DepthAnything
 from depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet
 
 
@@ -15,10 +15,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--img-path', type=str)
     parser.add_argument('--outdir', type=str, default='./vis_depth')
-    
-    parser.add_argument('--encoder', type=str, default='vitl')
-    parser.add_argument('--load-from', type=str, required=True)
-    parser.add_argument('--localhub', dest='localhub', action='store_true', default=False)
+    parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl'])
     
     args = parser.parse_args()
     
@@ -29,19 +26,13 @@
     font_scale = 1
     font_thickness = 2
     
-    assert args.encoder in ['vits', 'vitb', 'vitl']
-    if args.encoder == 'vits':
-        depth_anything = DPT_DINOv2(encoder='vits', features=64, out_channels=[48, 96, 192, 384], localhub=args.localhub).cuda()
-    elif args.encoder == 'vitb':
-        depth_anything = DPT_DINOv2(encoder='vitb', features=128, out_channels=[96, 192, 384, 768], localhub=args.localhub).cuda()
-    else:
-        depth_anything = DPT_DINOv2(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], localhub=args.localhub).cuda()
+    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+    
+    depth_anything = DepthAnything.from_pretrained('LiheYoung/depth_anything_{}14'.format(args.encoder)).to(DEVICE)
     
     total_params = sum(param.numel() for param in depth_anything.parameters())
     print('Total parameters: {:.2f}M'.format(total_params / 1e6))
     
-    depth_anything.load_state_dict(torch.load(args.load_from, map_location='cpu'), strict=True)
-    
     depth_anything.eval()
     
     transform = Compose([
@@ -76,7 +67,7 @@
         h, w = image.shape[:2]
         
         image = transform({'image': image})['image']
-        image = torch.from_numpy(image).unsqueeze(0).cuda()
+        image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
         
         with torch.no_grad():
             depth = depth_anything(image)
@@ -109,4 +100,3 @@
     final_result = cv2.vconcat([caption_space, combined_results])
     
     cv2.imwrite(os.path.join(args.outdir, filename[:filename.find('.')] + '_img_depth.png'), final_result)
-    
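
The net effect of this patch is that the manual checkpoint plumbing (`--load-from`, `--localhub`, and the per-encoder `DPT_DINOv2(...)` constructors) collapses into a single `DepthAnything.from_pretrained(...)` call backed by `PyTorchModelHubMixin`. The sketch below is a minimal, hypothetical smoke test for that new loading path, built only from pieces the diff itself introduces (`DepthAnything.from_pretrained`, the three encoder names, and the parameter count printed by `run.py`); the loop over all encoders and the 518x518 dummy input (an assumed size divisible by the ViT-14 patch size) are illustrative assumptions, not part of the patch.

```python
# Hypothetical smoke test for the new Hub loading path (not part of this patch).
import torch

from depth_anything.dpt import DepthAnything

for encoder in ['vits', 'vitb', 'vitl']:
    # Same repo-id pattern the patch uses in run.py and app.py.
    model = DepthAnything.from_pretrained('LiheYoung/depth_anything_{}14'.format(encoder)).eval()

    # Mirrors the parameter count that run.py prints after loading.
    total_params = sum(param.numel() for param in model.parameters())
    print('{}: {:.2f}M parameters'.format(encoder, total_params / 1e6))

    # Dummy forward pass; 518x518 is an assumed input size, chosen to be a
    # multiple of the ViT-14 patch size. Adjust to your own pre-processing.
    with torch.no_grad():
        depth = model(torch.randn(1, 3, 518, 518))
    print('{}: output shape {}'.format(encoder, tuple(depth.shape)))
```

The same loading check can also be driven through the new `__main__` block added to `depth_anything/dpt.py`, which instantiates the requested encoder from the Hub and prints the module.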