mirror of
https://github.com/huchenlei/Depth-Anything.git
synced 2026-01-26 15:29:46 +00:00
48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
import torch
|
|
from mmengine.model import BaseModule
|
|
from torch import nn
|
|
|
|
from mmseg.registry import MODELS
|
|
|
|
|
|
|
|
@MODELS.register_module()
|
|
class DINOv2(nn.Module):
|
|
"""Use DINOv2 pre-trained models
|
|
"""
|
|
|
|
def __init__(self, version='large', freeze=False, load_from=None):
|
|
super().__init__()
|
|
|
|
if version == 'large':
|
|
self.dinov2 = torch.hub.load('torchhub/facebookresearch_dinov2_main', 'dinov2_vitl14', source='local', pretrained=False)
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
if load_from is not None:
|
|
d = torch.load(load_from, map_location='cpu')
|
|
new_d = {}
|
|
for key, value in d.items():
|
|
if 'pretrained' in key:
|
|
new_d[key.replace('pretrained.', '')] = value
|
|
self.dinov2.load_state_dict(new_d)
|
|
|
|
self.freeze = freeze
|
|
|
|
def forward(self, inputs):
|
|
B, _, h, w = inputs.shape
|
|
|
|
if self.freeze:
|
|
with torch.no_grad():
|
|
features = self.dinov2.get_intermediate_layers(inputs, 4)
|
|
else:
|
|
features = self.dinov2.get_intermediate_layers(inputs, 4)
|
|
|
|
outs = []
|
|
for feature in features:
|
|
C = feature.shape[-1]
|
|
feature = feature.permute(0, 2, 1).reshape(B, C, h // 14, w // 14).contiguous()
|
|
outs.append(feature)
|
|
|
|
return outs
|