mirror of
https://github.com/huchenlei/Depth-Anything.git
synced 2026-01-26 15:29:46 +00:00
177 lines
5.6 KiB
Python
177 lines
5.6 KiB
Python
# MIT License
|
|
|
|
# Copyright (c) 2022 Intelligent Systems Lab Org
|
|
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
# copies of the Software, and to permit persons to whom the Software is
|
|
# furnished to do so, subject to the following conditions:
|
|
|
|
# The above copyright notice and this permission notice shall be included in all
|
|
# copies or substantial portions of the Software.
|
|
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
# File author: Shariq Farooq Bhat
|
|
|
|
from zoedepth.utils.misc import count_parameters, parallelize
|
|
from zoedepth.utils.config import get_config
|
|
from zoedepth.utils.arg_utils import parse_unknown
|
|
from zoedepth.trainers.builder import get_trainer
|
|
from zoedepth.models.builder import build_model
|
|
from zoedepth.data.data_mono import DepthDataLoader
|
|
import torch.utils.data.distributed
|
|
import torch.multiprocessing as mp
|
|
import torch
|
|
import numpy as np
|
|
from pprint import pprint
|
|
import argparse
|
|
import os
|
|
|
|
# Process-wide environment defaults, applied at import time:
#   PYOPENGL_PLATFORM=egl   -> PyOpenGL platform selection (presumably for
#                              headless rendering; confirm)
#   WANDB_START_METHOD=thread -> start method used by wandb
os.environ.update(
    PYOPENGL_PLATFORM="egl",
    WANDB_START_METHOD="thread",
)
|
|
|
|
|
|
def fix_random_seed(seed: int):
    """Seed every RNG used in training and configure cuDNN for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG, and torch (CPU, the
    current CUDA device and all CUDA devices), then asks cuDNN for
    deterministic, non-auto-tuned kernels.

    Args:
        seed: Seed value applied to every generator.
    """
    import random

    import numpy
    import torch

    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # With benchmark=True cuDNN times several convolution algorithms and keeps
    # the fastest, which can differ from run to run and so breaks the
    # reproducibility requested above. Disable it to keep results repeatable.
    torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
def load_ckpt(config, model, checkpoint_dir="./checkpoints", ckpt_type="best"):
    """Optionally load pretrained weights into ``model``.

    The checkpoint source is chosen from ``config``:

    * ``config.checkpoint``: an explicit checkpoint path, used as-is.
    * ``config.ckpt_pattern``: searched for in ``checkpoint_dir`` with the
      glob ``*{pattern}*{ckpt_type}*``. When several files match, the
      lexicographically first one is used.

    If ``config`` has neither attribute, ``model`` is returned unchanged.

    Args:
        config: Configuration object; only its ``checkpoint`` /
            ``ckpt_pattern`` attributes are read here.
        model: Model handed to ``load_wts``.
        checkpoint_dir: Directory searched when ``ckpt_pattern`` is used.
        ckpt_type: Substring that must follow the pattern in the matched
            file name (default ``"best"``).

    Returns:
        The model returned by ``load_wts``, or ``model`` itself when no
        checkpoint is configured.

    Raises:
        ValueError: If ``ckpt_pattern`` is set but no file matches it.
    """
    import glob
    import os

    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        # glob.glob returns files in arbitrary (filesystem) order; sort so the
        # same checkpoint is chosen on every machine and every run.
        matches = sorted(glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*")))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        if len(matches) > 1:
            print(f"Multiple checkpoints match the pattern {pattern}; "
                  f"using {matches[0]}")
        checkpoint = matches[0]
    else:
        return model

    # Deferred until a checkpoint is actually going to be loaded.
    from zoedepth.models.model_io import load_wts

    model = load_wts(model, checkpoint)
    print("Loaded weights from {0}".format(checkpoint))
    return model
|
|
|
|
|
|
def main_worker(gpu, ngpus_per_node, config):
    """Seed, build, optionally restore, and train a model for one worker.

    Called once per spawned process via ``mp.spawn`` when
    ``config.distributed`` is set, or directly otherwise.

    Args:
        gpu: Device index for this worker; stored on ``config.gpu`` and
            passed to the trainer as ``device``.
        ngpus_per_node: Number of visible CUDA devices. Not used in this
            body, but part of the ``mp.spawn(..., args=(ngpus_per_node,
            config))`` calling convention.
        config: Training configuration; mutated in place (``gpu``,
            ``total_params``).
    """
    try:
        # A seed of 0 is a legitimate choice; only fall back to the default
        # when no seed was supplied at all.
        if 'seed' in config and config.seed is not None:
            seed = config.seed
        else:
            seed = 43
        fix_random_seed(seed)

        config.gpu = gpu

        model = build_model(config)
        model = load_ckpt(config, model)
        model = parallelize(config, model)

        total_params = f"{round(count_parameters(model)/1e6,2)}M"
        config.total_params = total_params
        print(f"Total parameters : {total_params}")

        train_loader = DepthDataLoader(config, "train").data
        test_loader = DepthDataLoader(config, "online_eval").data

        trainer = get_trainer(config)(
            config, model, train_loader, test_loader, device=config.gpu)

        trainer.train()
    finally:
        # Finish the wandb run even when training raised.
        import wandb
        wandb.finish()
|
|
|
|
|
|
if __name__ == '__main__':
    mp.set_start_method('forkserver')

    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="synunet")
    parser.add_argument("-d", "--dataset", type=str, default='nyu')
    parser.add_argument("--trainer", type=str, default=None)

    # Options not declared above are parsed by parse_unknown and forwarded to
    # get_config as overrides.
    args, unknown_args = parser.parse_known_args()
    overwrite_kwargs = parse_unknown(unknown_args)

    overwrite_kwargs["model"] = args.model
    if args.trainer is not None:
        overwrite_kwargs["trainer"] = args.trainer

    config = get_config(args.model, "train", args.dataset, **overwrite_kwargs)
    if config.use_shared_dict:
        shared_dict = mp.Manager().dict()
    else:
        shared_dict = None
    config.shared_dict = shared_dict

    config.batch_size = config.bs
    config.mode = 'train'
    if config.root != ".":
        # exist_ok avoids the check-then-create race of isdir() + makedirs().
        os.makedirs(config.root, exist_ok=True)

    try:
        # NOTE(review): brackets are simply stripped, so a compressed SLURM
        # range such as "node[1-2]" becomes the single entry "node1-2";
        # expanding it properly (e.g. `scontrol show hostnames`) would be
        # needed for multi-node runs with such nodelists.
        node_str = os.environ['SLURM_JOB_NODELIST'].replace(
            '[', '').replace(']', '')
        nodes = node_str.split(',')

        config.world_size = len(nodes)
        config.rank = int(os.environ['SLURM_PROCID'])
    except KeyError:
        # We are NOT using SLURM
        config.world_size = 1
        config.rank = 0
        nodes = ["127.0.0.1"]

    if config.distributed:
        print(config.rank)
        slurm_job_id = os.environ.get("SLURM_JOB_ID", "")
        if slurm_job_id.isdigit():
            # Each SLURM task executes this independently, and the NumPy RNG
            # has not been seeded yet, so a random draw would hand every node
            # a different rendezvous port. Deriving the port from the job id
            # gives all tasks of one job the same value in the same range.
            port = 15000 + int(slurm_job_id) % 25
        else:
            port = np.random.randint(15000, 15025)
        config.dist_url = 'tcp://{}:{}'.format(nodes[0], port)
        print(config.dist_url)
        config.dist_backend = 'nccl'
        config.gpu = None

    ngpus_per_node = torch.cuda.device_count()
    config.num_workers = config.workers
    config.ngpus_per_node = ngpus_per_node
    print("Config:")
    pprint(config)
    if config.distributed:
        config.world_size = ngpus_per_node * config.world_size
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, config))
    else:
        if ngpus_per_node == 1:
            config.gpu = 0
        main_worker(config.gpu, ngpus_per_node, config)
|