diff --git a/configs/retrieval_msrvtt.yaml b/configs/retrieval_msrvtt.yaml
new file mode 100644
index 0000000..395f625
--- /dev/null
+++ b/configs/retrieval_msrvtt.yaml
@@ -0,0 +1,12 @@
+video_root: '/export/share/dongxuli/data/msrvtt_retrieval/videos'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size: 64
+k_test: 128
+image_size: 384
+num_frm_test: 8
\ No newline at end of file
diff --git a/data/video_dataset.py b/data/video_dataset.py
new file mode 100644
index 0000000..0a6f8a6
--- /dev/null
+++ b/data/video_dataset.py
@@ -0,0 +1,110 @@
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+import torch
+import numpy as np
+import random
+import decord
+from decord import VideoReader
+import json
+import os
+from data.utils import pre_caption
+
+decord.bridge.set_bridge("torch")
+
+class ImageNorm(object):
+    """Apply Normalization to Image Pixels on GPU
+    """
+    def __init__(self, mean, std):
+        self.mean = torch.tensor(mean).view(1, 3, 1, 1)
+        self.std = torch.tensor(std).view(1, 3, 1, 1)
+
+    def __call__(self, img):
+
+        if torch.max(img) > 1 and self.mean.max() <= 1:
+            img.div_(255.)
+        return img.sub_(self.mean).div_(self.std)
+
+def load_jsonl(filename):
+    with open(filename, "r") as f:
+        return [json.loads(l.strip("\n")) for l in f.readlines()]
+
+
+class VideoDataset(Dataset):
+
+    def __init__(self, video_root, ann_root, num_frm=4, frm_sampling_strategy="rand", max_img_size=384, video_fmt='.mp4'):
+        '''
+        video_root (string): Root directory of videos
+        ann_root (string): directory to store the annotation file
+        '''
+        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/msrvtt_test.jsonl'
+        filename = 'msrvtt_test.jsonl'
+
+        download_url(url,ann_root)
+        self.annotation = load_jsonl(os.path.join(ann_root,filename))
+
+        self.num_frm = num_frm
+        self.frm_sampling_strategy = frm_sampling_strategy
+        self.max_img_size = max_img_size
+        self.video_root = video_root
+        self.video_fmt = video_fmt
+        self.img_norm = ImageNorm(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+
+        self.text = [pre_caption(ann['caption'],40) for ann in self.annotation]
+        self.txt2video = [i for i in range(len(self.annotation))]
+        self.video2txt = self.txt2video
+
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        video_path = os.path.join(self.video_root, ann['clip_name'] + self.video_fmt)
+
+        vid_frm_array = self._load_video_from_path_decord(video_path, height=self.max_img_size, width=self.max_img_size)
+
+        video = self.img_norm(vid_frm_array.float())
+
+        return video, ann['clip_name']
+
+
+
+    def _load_video_from_path_decord(self, video_path, height=None, width=None, start_time=None, end_time=None, fps=-1):
+        try:
+            if not height or not width:
+                vr = VideoReader(video_path)
+            else:
+                vr = VideoReader(video_path, width=width, height=height)
+
+            vlen = len(vr)
+
+            if start_time or end_time:
+                assert fps > 0, 'must provide video fps if specifying start and end time.'
+
+                start_idx = min(int(start_time * fps), vlen)
+                end_idx = min(int(end_time * fps), vlen)
+            else:
+                start_idx, end_idx = 0, vlen
+
+            if self.frm_sampling_strategy == 'uniform':
+                frame_indices = np.arange(start_idx, end_idx, vlen / self.num_frm, dtype=int)
+            elif self.frm_sampling_strategy == 'rand':
+                frame_indices = sorted(random.sample(range(vlen), self.num_frm))
+            elif self.frm_sampling_strategy == 'headtail':
+                frame_indices_head = sorted(random.sample(range(vlen // 2), self.num_frm // 2))
+                frame_indices_tail = sorted(random.sample(range(vlen // 2, vlen), self.num_frm // 2))
+                frame_indices = frame_indices_head + frame_indices_tail
+            else:
+                raise NotImplementedError('Invalid sampling strategy {} '.format(self.frm_sampling_strategy))
+
+            raw_sample_frms = vr.get_batch(frame_indices)
+        except Exception as e:
+            return None
+
+        raw_sample_frms = raw_sample_frms.permute(0, 3, 1, 2)
+
+        return raw_sample_frms
diff --git a/eval_retrieval_video.py b/eval_retrieval_video.py
new file mode 100644
index 0000000..07ebab7
--- /dev/null
+++ b/eval_retrieval_video.py
@@ -0,0 +1,250 @@
+'''
+ * Copyright (c) 2022, salesforce.com, inc.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+ * By Junnan Li
+'''
+import argparse
+import os
+import ruamel_yaml as yaml
+import numpy as np
+import random
+import time
+import datetime
+import json
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+from torch.utils.data import DataLoader
+
+from models.blip_retrieval import blip_retrieval
+import utils
+from data.video_dataset import VideoDataset
+
+
+@torch.no_grad()
+def evaluation(model, data_loader, tokenizer, device, config):
+    # test
+    model.eval()
+
+    metric_logger = utils.MetricLogger(delimiter="  ")
+    header = 'Evaluation:'
+
+    print('Computing features for evaluation...')
+    start_time = time.time()
+
+    texts = data_loader.dataset.text
+    num_text = len(texts)
+    text_bs = 256
+    text_ids = []
+    text_embeds = []
+    text_atts = []
+    for i in range(0, num_text, text_bs):
+        text = texts[i: min(num_text, i+text_bs)]
+        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+        text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+        text_embeds.append(text_embed)
+        text_ids.append(text_input.input_ids)
+        text_atts.append(text_input.attention_mask)
+
+    text_embeds = torch.cat(text_embeds,dim=0)
+    text_ids = torch.cat(text_ids,dim=0)
+    text_atts = torch.cat(text_atts,dim=0)
+    text_ids[:,0] = tokenizer.additional_special_tokens_ids[0]
+
+    video_feats = []
+    video_embeds = []
+    for video, video_id in data_loader:
+
+        B,N,C,W,H = video.size()
+        video = video.view(-1,C,W,H)
+        video = video.to(device,non_blocking=True)
+        video_feat = model.visual_encoder(video)
+        video_embed = model.vision_proj(video_feat[:,0,:])
+        video_embed = video_embed.view(B,N,-1).mean(dim=1)
+        video_embed = F.normalize(video_embed,dim=-1)
+
+        video_feat = video_feat.view(B,-1,video_feat.shape[-1])
+        video_feats.append(video_feat.cpu())
+        video_embeds.append(video_embed)
+
+    video_feats = torch.cat(video_feats,dim=0)
+    video_embeds = torch.cat(video_embeds,dim=0)
+
+    sims_matrix = video_embeds @ text_embeds.t()
+    score_matrix_v2t = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+    num_tasks = utils.get_world_size()
+    rank = utils.get_rank()
+    step = sims_matrix.size(0)//num_tasks + 1
+    start = rank*step
+    end = min(sims_matrix.size(0),start+step)
+
+    for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+        encoder_output = video_feats[start+i].repeat(config['k_test'],1,1).to(device,non_blocking=True)
+        encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+        output = model.text_encoder(text_ids[topk_idx],
+                                    attention_mask = text_atts[topk_idx],
+                                    encoder_hidden_states = encoder_output,
+                                    encoder_attention_mask = encoder_att,
+                                    return_dict = True,
+                                   )
+        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+        score_matrix_v2t[start+i,topk_idx] = score + topk_sim
+
+    sims_matrix = sims_matrix.t()
+    score_matrix_t2v = torch.full((len(texts),len(texts)),-100.0).to(device)
+
+    step = sims_matrix.size(0)//num_tasks + 1
+    start = rank*step
+    end = min(sims_matrix.size(0),start+step)
+
+    for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+        encoder_output = video_feats[topk_idx].to(device,non_blocking=True)
+        encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device,non_blocking=True)
+        output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+                                    attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+                                    encoder_hidden_states = encoder_output,
+                                    encoder_attention_mask = encoder_att,
+                                    return_dict = True,
+                                   )
+        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+        score_matrix_t2v[start+i,topk_idx] = score + topk_sim
+
+    if args.distributed:
+        dist.barrier()
+        torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
+        torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)
+
+    total_time = time.time() - start_time
+    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+    print('Evaluation time {}'.format(total_time_str))
+
+    return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
+
+
+
+@torch.no_grad()
+def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
+
+    #Video->Text
+    ranks = np.zeros(scores_v2t.shape[0])
+    for index,score in enumerate(scores_v2t):
+        inds = np.argsort(score)[::-1]
+        ranks[index] = np.where(inds == vid2txt[index])[0][0]
+
+    # Compute metrics
+    tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+    tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+    tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+    #Text->Video
+    ranks = np.zeros(scores_t2v.shape[0])
+
+    for index,score in enumerate(scores_t2v):
+        inds = np.argsort(score)[::-1]
+        ranks[index] = np.where(inds == txt2vmg[index])[0][0]
+
+    mdR = np.median(ranks+1)
+
+    # Compute metrics
+    vr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+    vr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+    vr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+    tr_mean = (tr1 + tr5 + tr10) / 3
+    vr_mean = (vr1 + vr5 + vr10) / 3
+    r_mean = (tr_mean + vr_mean) / 2
+
+    eval_result = {'txt_r1': tr1,
+                   'txt_r5': tr5,
+                   'txt_r10': tr10,
+                   'txt_r_mean': tr_mean,
+                   'vid_r1': vr1,
+                   'vid_r5': vr5,
+                   'vid_r10': vr10,
+                   'vid_r_mean': vr_mean,
+                   'vid_mdR': mdR,
+                   'r_mean': r_mean}
+    return eval_result
+
+
+
+
+def main(args, config):
+    utils.init_distributed_mode(args)
+
+    device = torch.device(args.device)
+
+    # fix the seed for reproducibility
+    seed = args.seed + utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    cudnn.benchmark = True
+
+    #### Dataset ####
+    print("Creating retrieval dataset")
+    test_dataset = VideoDataset(config['video_root'],config['ann_root'],num_frm=config['num_frm_test'],
+                                max_img_size=config['image_size'], frm_sampling_strategy='uniform')
+
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=config['batch_size'],
+        num_workers=4,
+        pin_memory=True,
+        drop_last=False,
+        shuffle=False,
+    )
+
+    #### Model ####
+    print("Creating model")
+    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'])
+
+    model = model.to(device)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
+    score_v2t, score_t2v, = evaluation(model_without_ddp, test_loader, model_without_ddp.tokenizer, device, config)
+
+    if utils.is_main_process():
+
+        test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
+        print(test_result)
+
+        log_stats = {**{f'{k}': v for k, v in test_result.items()},}
+        with open(os.path.join(args.output_dir, "test_result.txt"),"a") as f:
+            f.write(json.dumps(log_stats) + "\n")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
+    parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
+    parser.add_argument('--device', default='cuda')
+    parser.add_argument('--seed', default=42, type=int)
+    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+    parser.add_argument('--distributed', default=True, type=bool)
+    args = parser.parse_args()
+
+    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+    main(args, config)
\ No newline at end of file
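
The new VideoDataset can be exercised on its own before running the full evaluation script. The snippet below is a minimal sketch, not part of the diff: it assumes decord is installed, that the MSRVTT test clips are present locally, and that 'path/to/msrvtt/videos' is a placeholder to be replaced with the real video_root; the msrvtt_test.jsonl annotation file is fetched automatically inside VideoDataset.__init__.

from torch.utils.data import DataLoader
from data.video_dataset import VideoDataset

# Placeholder path: point this at the directory holding the MSRVTT .mp4 clips.
dataset = VideoDataset(video_root='path/to/msrvtt/videos',
                       ann_root='annotation',
                       num_frm=8,                        # matches num_frm_test in retrieval_msrvtt.yaml
                       frm_sampling_strategy='uniform',  # the strategy used by eval_retrieval_video.py
                       max_img_size=384)

loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=4)

# Each batch is (video, clip_name): video has shape (batch, num_frm, 3, 384, 384)
# and is already normalized by ImageNorm; clip_name is a list of clip ids.
video, clip_name = next(iter(loader))
print(video.shape, clip_name[:2])

For the full zero-shot evaluation, eval_retrieval_video.py reads the same settings from configs/retrieval_msrvtt.yaml and averages the per-frame [CLS] embeddings into a single video embedding before ranking against the text embeddings.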