Load models from Hugging Face

This commit is contained in:
Lihe Yang
2024-01-23 17:25:15 +08:00
committed by GitHub
parent a61bb5af0e
commit c3390b83bb
5 changed files with 45 additions and 35 deletions

View File

@@ -1,8 +1,10 @@
import argparse
import torch
import torch.nn as nn
from .blocks import FeatureFusionBlock, _make_scratch
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from depth_anything.blocks import FeatureFusionBlock, _make_scratch
def _make_fusion_block(features, use_bn, size = None):
@@ -164,7 +166,22 @@ class DPT_DINOv2(nn.Module):
return depth.squeeze(1)
class DepthAnything(DPT_DINOv2, PyTorchModelHubMixin):
def __init__(self, config):
super().__init__(**config)
if __name__ == '__main__':
depth_anything = DPT_DINOv2()
depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_dinov2_vitl14.pth'))
parser = argparse.ArgumentParser()
parser.add_argument(
"--encoder",
default="vits",
type=str,
choices=["vits", "vitb", "vitl"],
)
args = parser.parse_args()
model = DepthAnything.from_pretrained("LiheYoung/depth_anything_{:}14".format(args.encoder))
print(model)