mirror of
https://github.com/huchenlei/Depth-Anything.git
synced 2026-04-30 12:21:13 +00:00
Load models from Hugging Face
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import FeatureFusionBlock, _make_scratch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
|
||||
from depth_anything.blocks import FeatureFusionBlock, _make_scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features, use_bn, size = None):
|
||||
@@ -164,7 +166,22 @@ class DPT_DINOv2(nn.Module):
|
||||
return depth.squeeze(1)
|
||||
|
||||
|
||||
class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
|
||||
def __init__(self, config):
|
||||
super().__init__(**config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
depth_anything = DPT_DINOv2()
|
||||
depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_dinov2_vitl14.pth'))
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
default="vits",
|
||||
type=str,
|
||||
choices=["vits", "vitb", "vitl"],
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
|
||||
|
||||
print(model)
|
||||
|
||||
Reference in New Issue
Block a user