mirror of
https://github.com/huchenlei/Depth-Anything.git
synced 2026-05-01 04:41:13 +00:00
Initial commit
This commit is contained in:
176
metric_depth/train_mono.py
Normal file
176
metric_depth/train_mono.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2022 Intelligent Systems Lab Org
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# File author: Shariq Farooq Bhat
|
||||
|
||||
from zoedepth.utils.misc import count_parameters, parallelize
|
||||
from zoedepth.utils.config import get_config
|
||||
from zoedepth.utils.arg_utils import parse_unknown
|
||||
from zoedepth.trainers.builder import get_trainer
|
||||
from zoedepth.models.builder import build_model
|
||||
from zoedepth.data.data_mono import DepthDataLoader
|
||||
import torch.utils.data.distributed
|
||||
import torch.multiprocessing as mp
|
||||
import torch
|
||||
import numpy as np
|
||||
from pprint import pprint
|
||||
import argparse
|
||||
import os
|
||||
|
||||
os.environ["PYOPENGL_PLATFORM"] = "egl"
|
||||
os.environ["WANDB_START_METHOD"] = "thread"
|
||||
|
||||
|
||||
def fix_random_seed(seed: int):
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
random.seed(seed)
|
||||
numpy.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"):
|
||||
import glob
|
||||
import os
|
||||
|
||||
from zoedepth.models.model_io import load_wts
|
||||
|
||||
if hasattr(config, "checkpoint"):
|
||||
checkpoint = config.checkpoint
|
||||
elif hasattr(config, "ckpt_pattern"):
|
||||
pattern = config.ckpt_pattern
|
||||
matches = glob.glob(os.path.join(
|
||||
checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
|
||||
if not (len(matches) > 0):
|
||||
raise ValueError(f"No matches found for the pattern {pattern}")
|
||||
|
||||
checkpoint = matches[0]
|
||||
|
||||
else:
|
||||
return model
|
||||
model = load_wts(model, checkpoint)
|
||||
print("Loaded weights from {0}".format(checkpoint))
|
||||
return model
|
||||
|
||||
|
||||
def main_worker(gpu, ngpus_per_node, config):
|
||||
try:
|
||||
seed = config.seed if 'seed' in config and config.seed else 43
|
||||
fix_random_seed(seed)
|
||||
|
||||
config.gpu = gpu
|
||||
|
||||
model = build_model(config)
|
||||
# print(model)
|
||||
|
||||
model = load_ckpt(config, model)
|
||||
model = parallelize(config, model)
|
||||
|
||||
total_params = f"{round(count_parameters(model)/1e6,2)}M"
|
||||
config.total_params = total_params
|
||||
print(f"Total parameters : {total_params}")
|
||||
|
||||
train_loader = DepthDataLoader(config, "train").data
|
||||
test_loader = DepthDataLoader(config, "online_eval").data
|
||||
|
||||
trainer = get_trainer(config)(
|
||||
config, model, train_loader, test_loader, device=config.gpu)
|
||||
|
||||
trainer.train()
|
||||
finally:
|
||||
import wandb
|
||||
wandb.finish()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mp.set_start_method('forkserver')
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, default="synunet")
|
||||
parser.add_argument("-d", "--dataset", type=str, default='nyu')
|
||||
parser.add_argument("--trainer", type=str, default=None)
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
overwrite_kwargs = parse_unknown(unknown_args)
|
||||
|
||||
overwrite_kwargs["model"] = args.model
|
||||
if args.trainer is not None:
|
||||
overwrite_kwargs["trainer"] = args.trainer
|
||||
|
||||
config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)
|
||||
# git_commit()
|
||||
if config.use_shared_dict:
|
||||
shared_dict = mp.Manager().dict()
|
||||
else:
|
||||
shared_dict = None
|
||||
config.shared_dict = shared_dict
|
||||
|
||||
config.batch_size = config.bs
|
||||
config.mode = 'train'
|
||||
if config.root != "." and not os.path.isdir(config.root):
|
||||
os.makedirs(config.root)
|
||||
|
||||
try:
|
||||
node_str = os.environ['SLURM_JOB_NODELIST'].replace(
|
||||
'[', '').replace(']', '')
|
||||
nodes = node_str.split(',')
|
||||
|
||||
config.world_size = len(nodes)
|
||||
config.rank = int(os.environ['SLURM_PROCID'])
|
||||
# config.save_dir = "/ibex/scratch/bhatsf/videodepth/checkpoints"
|
||||
|
||||
except KeyError as e:
|
||||
# We are NOT using SLURM
|
||||
config.world_size = 1
|
||||
config.rank = 0
|
||||
nodes = ["127.0.0.1"]
|
||||
|
||||
if config.distributed:
|
||||
|
||||
print(config.rank)
|
||||
port = np.random.randint(15000, 15025)
|
||||
config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
|
||||
print(config.dist_url)
|
||||
config.dist_backend = 'nccl'
|
||||
config.gpu = None
|
||||
|
||||
ngpus_per_node = torch.cuda.device_count()
|
||||
config.num_workers = config.workers
|
||||
config.ngpus_per_node = ngpus_per_node
|
||||
print("Config:")
|
||||
pprint(config)
|
||||
if config.distributed:
|
||||
config.world_size = ngpus_per_node * config.world_size
|
||||
mp.spawn(main_worker, nprocs=ngpus_per_node,
|
||||
args=(ngpus_per_node, config))
|
||||
else:
|
||||
if ngpus_per_node == 1:
|
||||
config.gpu = 0
|
||||
main_worker(config.gpu, ngpus_per_node, config)
|
||||
Reference in New Issue
Block a user