From f5eacc9f082f53372b5534d0b0b9a7228dd17caf Mon Sep 17 00:00:00 2001
From: root <junnan.li@salesforce.com>
Date: Thu, 27 Jan 2022 12:51:05 +0000
Subject: [PATCH] add datasets

---
 data/__init__.py              | 101 +++++++++++++++++++++++++++
 data/coco_karpathy_dataset.py | 126 ++++++++++++++++++++++++++++++++++
 data/flickr30k_dataset.py     |  93 +++++++++++++++++++++++++
 data/nlvr_dataset.py          |  78 +++++++++++++++++++++
 data/nocaps_dataset.py        |  32 +++++++++
 data/pretrain_dataset.py      |  59 ++++++++++++++++
 data/utils.py                 | 112 ++++++++++++++++++++++++++++++
 data/vqa_dataset.py           |  88 ++++++++++++++++++++++++
 8 files changed, 689 insertions(+)
 create mode 100644 data/__init__.py
 create mode 100644 data/coco_karpathy_dataset.py
 create mode 100644 data/flickr30k_dataset.py
 create mode 100644 data/nlvr_dataset.py
 create mode 100644 data/nocaps_dataset.py
 create mode 100644 data/pretrain_dataset.py
 create mode 100644 data/utils.py
 create mode 100644 data/vqa_dataset.py

diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..0be209a
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,101 @@
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
+from data.nocaps_dataset import nocaps_eval
+from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
+from data.vqa_dataset import vqa_dataset
+from data.nlvr_dataset import nlvr_dataset
+from data.pretrain_dataset import pretrain_dataset
+from transform.randaugment import RandomAugment
+
+def create_dataset(dataset, config, min_scale=0.5):
+
+    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+    transform_train = transforms.Compose([
+        transforms.RandomResizedCrop(config['image_size'], scale=(min_scale, 1.0), interpolation=InterpolationMode.BICUBIC),
+        transforms.RandomHorizontalFlip(),
+        RandomAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize',
+                                              'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+        transforms.ToTensor(),
+        normalize,
+    ])
+    transform_test = transforms.Compose([
+        transforms.Resize((config['image_size'], config['image_size']), interpolation=InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        normalize,
+    ])
+
+    if dataset=='pretrain':
+        dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
+        return dataset
+
+    elif dataset=='caption_coco':
+        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
+        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='nocaps':
+        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return val_dataset, test_dataset
+
+    elif dataset=='retrieval_coco':
+        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
+        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='retrieval_flickr':
+        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
+        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='vqa':
+        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
+                                    train_files=config['train_files'], split='train')
+        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
+        return train_dataset, test_dataset
+
+    elif dataset=='nlvr':
+        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'], 'train')
+        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+
+def create_sampler(datasets, shuffles, num_tasks, global_rank):
+    samplers = []
+    for dataset, shuffle in zip(datasets, shuffles):
+        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
+        samplers.append(sampler)
+    return samplers
+
+
+def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+    loaders = []
+    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+        if is_train:
+            shuffle = (sampler is None)
+            drop_last = True
+        else:
+            shuffle = False
+            drop_last = False
+        loader = DataLoader(
+            dataset,
+            batch_size=bs,
+            num_workers=n_worker,
+            pin_memory=True,
+            sampler=sampler,
+            shuffle=shuffle,
+            collate_fn=collate_fn,
+            drop_last=drop_last,
+        )
+        loaders.append(loader)
+    return loaders
+
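Note (not part of the patch): a minimal sketch of how the three factories above compose, here for the caption_coco task. Everything beyond the imported names is an assumption — the config values (384 is only illustrative), the batch sizes, and the use of None samplers, which is the non-distributed fallback (create_loader then sets shuffle=True for the training set itself).

from data import create_dataset, create_loader

config = {'image_size': 384, 'image_root': 'coco/images/',
          'ann_root': 'annotation', 'prompt': 'a picture of '}

train_ds, val_ds, test_ds = create_dataset('caption_coco', config)

train_loader, val_loader, test_loader = create_loader(
    [train_ds, val_ds, test_ds], samplers=[None, None, None],
    batch_size=[32, 64, 64], num_workers=[4, 4, 4],
    is_trains=[True, False, False], collate_fns=[None, None, None])

In a distributed run, create_sampler([train_ds, val_ds, test_ds], [True, False, False], num_tasks, global_rank) would replace the [None, None, None] list.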
diff --git a/data/coco_karpathy_dataset.py b/data/coco_karpathy_dataset.py
new file mode 100644
index 0000000..a34d292
--- /dev/null
+++ b/data/coco_karpathy_dataset.py
@@ -0,0 +1,126 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class coco_karpathy_train(Dataset):
+    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+        '''
+        image_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        '''
+        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
+        filename = 'coco_karpathy_train.json'
+
+        download_url(url, ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root, filename), 'r'))
+        self.transform = transform
+        self.image_root = image_root
+        self.max_words = max_words
+        self.prompt = prompt
+
+        self.img_ids = {}
+        n = 0
+        for ann in self.annotation:
+            img_id = ann['image_id']
+            if img_id not in self.img_ids.keys():
+                self.img_ids[img_id] = n
+                n += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root, ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        caption = self.prompt + pre_caption(ann['caption'], self.max_words)
+
+        return image, caption, self.img_ids[ann['image_id']]
+
+
+class coco_karpathy_caption_eval(Dataset):
+    def __init__(self, transform, image_root, ann_root, split):
+        '''
+        image_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        split (string): val or test
+        '''
+        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}
+
+        download_url(urls[split], ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
+        self.transform = transform
+        self.image_root = image_root
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root, ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
+
+        return image, int(img_id)
+
+
+class coco_karpathy_retrieval_eval(Dataset):
+    def __init__(self, transform, image_root, ann_root, split, max_words=30):
+        '''
+        image_root (string): Root directory of images (e.g. coco/images/)
+        ann_root (string): directory to store the annotation file
+        split (string): val or test
+        '''
+        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+        filenames = {'val': 'coco_karpathy_val.json', 'test': 'coco_karpathy_test.json'}
+
+        download_url(urls[split], ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
+        self.transform = transform
+        self.image_root = image_root
+
+        self.text = []
+        self.image = []
+        self.txt2img = {}
+        self.img2txt = {}
+
+        txt_id = 0
+        for img_id, ann in enumerate(self.annotation):
+            self.image.append(ann['image'])
+            self.img2txt[img_id] = []
+            for i, caption in enumerate(ann['caption']):
+                self.text.append(pre_caption(caption, max_words))
+                self.img2txt[img_id].append(txt_id)
+                self.txt2img[txt_id] = img_id
+                txt_id += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        return image, index
\ No newline at end of file
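Note (not part of the patch): coco_karpathy_retrieval_eval flattens the ~5 captions per image into one list and keeps two index maps, which a retrieval evaluation loop consumes to check whether a retrieved caption belongs to the query image:

dataset.text        # all pre-processed captions, flattened in annotation order
dataset.img2txt[i]  # indices into dataset.text of the captions describing image i
dataset.txt2img[t]  # index of the image that caption t describes

One caveat worth flagging in coco_karpathy_caption_eval: str.strip('.jpg') strips a trailing *character set*, not the literal suffix. It is safe for COCO_*_000000xxxxxx.jpg names because the id portion is purely numeric, but it would mangle filenames whose stem ends in 'j', 'p', or 'g'; os.path.splitext would be the robust alternative.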
diff --git a/data/flickr30k_dataset.py b/data/flickr30k_dataset.py
new file mode 100644
index 0000000..018ab38
--- /dev/null
+++ b/data/flickr30k_dataset.py
@@ -0,0 +1,93 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class flickr30k_train(Dataset):
+    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+        '''
+        image_root (string): Root directory of images (e.g. flickr30k/)
+        ann_root (string): directory to store the annotation file
+        '''
+        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
+        filename = 'flickr30k_train.json'
+
+        download_url(url, ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root, filename), 'r'))
+        self.transform = transform
+        self.image_root = image_root
+        self.max_words = max_words
+        self.prompt = prompt
+
+        self.img_ids = {}
+        n = 0
+        for ann in self.annotation:
+            img_id = ann['image_id']
+            if img_id not in self.img_ids.keys():
+                self.img_ids[img_id] = n
+                n += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root, ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        caption = self.prompt + pre_caption(ann['caption'], self.max_words)
+
+        return image, caption, self.img_ids[ann['image_id']]
+
+
+class flickr30k_retrieval_eval(Dataset):
+    def __init__(self, transform, image_root, ann_root, split, max_words=30):
+        '''
+        image_root (string): Root directory of images (e.g. flickr30k/)
+        ann_root (string): directory to store the annotation file
+        split (string): val or test
+        '''
+        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+        filenames = {'val': 'flickr30k_val.json', 'test': 'flickr30k_test.json'}
+
+        download_url(urls[split], ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
+        self.transform = transform
+        self.image_root = image_root
+
+        self.text = []
+        self.image = []
+        self.txt2img = {}
+        self.img2txt = {}
+
+        txt_id = 0
+        for img_id, ann in enumerate(self.annotation):
+            self.image.append(ann['image'])
+            self.img2txt[img_id] = []
+            for i, caption in enumerate(ann['caption']):
+                self.text.append(pre_caption(caption, max_words))
+                self.img2txt[img_id].append(txt_id)
+                self.txt2img[txt_id] = img_id
+                txt_id += 1
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        return image, index
\ No newline at end of file
diff --git a/data/nlvr_dataset.py b/data/nlvr_dataset.py
new file mode 100644
index 0000000..a8d6b2d
--- /dev/null
+++ b/data/nlvr_dataset.py
@@ -0,0 +1,78 @@
+import os
+import json
+import random
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+from data.utils import pre_caption
+
+class nlvr_dataset(Dataset):
+    def __init__(self, transform, image_root, ann_root, split):
+        '''
+        image_root (string): Root directory of images
+        ann_root (string): directory to store the annotation file
+        split (string): train, val or test
+        '''
+        urls = {'train': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
+                'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
+                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
+        filenames = {'train': 'nlvr_train.json', 'val': 'nlvr_dev.json', 'test': 'nlvr_test.json'}
+
+        download_url(urls[split], ann_root)
+        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
+
+        self.transform = transform
+        self.image_root = image_root
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image0_path = os.path.join(self.image_root, ann['images'][0])
+        image0 = Image.open(image0_path).convert('RGB')
+        image0 = self.transform(image0)
+
+        image1_path = os.path.join(self.image_root, ann['images'][1])
+        image1 = Image.open(image1_path).convert('RGB')
+        image1 = self.transform(image1)
+
+        sentence = pre_caption(ann['sentence'], 40)
+
+        if ann['label']=='True':
+            label = 1
+        else:
+            label = 0
+
+        words = sentence.split(' ')
+
+        if 'left' not in words and 'right' not in words:
+            if random.random()<0.5:
+                return image0, image1, sentence, label
+            else:
+                return image1, image0, sentence, label
+        else:
+            if random.random()<0.5:
+                return image0, image1, sentence, label
+            else:
+                new_words = []
+                for word in words:
+                    if word=='left':
+                        new_words.append('right')
+                    elif word=='right':
+                        new_words.append('left')
+                    else:
+                        new_words.append(word)
+
+                sentence = ' '.join(new_words)
+                return image1, image0, sentence, label
\ No newline at end of file
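Note (not part of the patch): the __getitem__ above augments NLVR2 training data by swapping the image pair with probability 0.5; this is label-preserving only if 'left' and 'right' are also exchanged in the sentence, which is exactly what the new_words loop does. A standalone illustration of the word swap:

sentence = 'the dog on the left is larger than the dog on the right'
swapped = ' '.join({'left': 'right', 'right': 'left'}.get(w, w)
                   for w in sentence.split(' '))
# -> 'the dog on the right is larger than the dog on the left'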
diff --git a/data/nocaps_dataset.py b/data/nocaps_dataset.py
new file mode 100644
index 0000000..ba0bed0
--- /dev/null
+++ b/data/nocaps_dataset.py
@@ -0,0 +1,32 @@
+import os
+import json
+
+from torch.utils.data import Dataset
+from torchvision.datasets.utils import download_url
+
+from PIL import Image
+
+class nocaps_eval(Dataset):
+    def __init__(self, transform, image_root, ann_root, split):
+        urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
+                'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
+        filenames = {'val': 'nocaps_val.json', 'test': 'nocaps_test.json'}
+
+        download_url(urls[split], ann_root)
+
+        self.annotation = json.load(open(os.path.join(ann_root, filenames[split]), 'r'))
+        self.transform = transform
+        self.image_root = image_root
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image_path = os.path.join(self.image_root, ann['image'])
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        return image, int(ann['img_id'])
\ No newline at end of file
diff --git a/data/pretrain_dataset.py b/data/pretrain_dataset.py
new file mode 100644
index 0000000..703d543
--- /dev/null
+++ b/data/pretrain_dataset.py
@@ -0,0 +1,59 @@
+import json
+import os
+import random
+import glob
+
+from torch.utils.data import Dataset
+
+from PIL import Image
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.MAX_IMAGE_PIXELS = None
+
+from data.utils import pre_caption
+
+class pretrain_dataset(Dataset):
+    def __init__(self, ann_file, laion_path, transform):
+
+        self.ann_pretrain = []
+        for f in ann_file:
+            print('loading '+f)
+            ann = json.load(open(f, 'r'))
+            self.ann_pretrain += ann
+
+        self.laion_path = laion_path
+        if self.laion_path:
+            self.laion_files = glob.glob(os.path.join(laion_path, '*.json'))
+
+            print('loading '+self.laion_files[0])
+            with open(self.laion_files[0], 'r') as f:
+                self.ann_laion = json.load(f)
+
+            self.annotation = self.ann_pretrain + self.ann_laion
+        else:
+            self.annotation = self.ann_pretrain
+
+        self.transform = transform
+
+    def reload_laion(self, epoch):
+        n = epoch % len(self.laion_files)
+        print('loading '+self.laion_files[n])
+        with open(self.laion_files[n], 'r') as f:
+            self.ann_laion = json.load(f)
+
+        self.annotation = self.ann_pretrain + self.ann_laion
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        image = Image.open(ann['image']).convert('RGB')
+        image = self.transform(image)
+        caption = pre_caption(ann['caption'], 30)
+
+        return image, caption
\ No newline at end of file
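Note (not part of the patch): reload_laion is meant to be called once per epoch from the training script, so each epoch pairs the fixed pretraining corpus (ann_pretrain) with a different LAION shard; epoch % len(laion_files) cycles through the shards. A sketch, where start_epoch and max_epoch are hypothetical loop bounds:

dataset = create_dataset('pretrain', config)

for epoch in range(start_epoch, max_epoch):
    if config['laion_path']:
        dataset.reload_laion(epoch)   # swaps in shard epoch % len(laion_files)
    # ... rebuild or refresh the DataLoader (len(dataset) may change), train one epoch ...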
diff --git a/data/utils.py b/data/utils.py
new file mode 100644
index 0000000..6288948
--- /dev/null
+++ b/data/utils.py
@@ -0,0 +1,112 @@
+import re
+import json
+import os
+
+import torch
+import torch.distributed as dist
+
+import utils
+
+def pre_caption(caption, max_words=50):
+    caption = re.sub(
+        r"([.!\"()*#:;~])",
+        ' ',
+        caption.lower(),
+    )
+    caption = re.sub(
+        r"\s{2,}",
+        ' ',
+        caption,
+    )
+    caption = caption.rstrip('\n')
+    caption = caption.strip(' ')
+
+    # truncate caption
+    caption_words = caption.split(' ')
+    if len(caption_words) > max_words:
+        caption = ' '.join(caption_words[:max_words])
+
+    return caption
+
+def pre_question(question, max_ques_words=50):
+    question = re.sub(
+        r"([.!\"()*#:;~])",
+        '',
+        question.lower(),
+    )
+    question = question.rstrip(' ')
+
+    # truncate question
+    question_words = question.split(' ')
+    if len(question_words) > max_ques_words:
+        question = ' '.join(question_words[:max_ques_words])
+
+    return question
+
+
+def save_result(result, result_dir, filename, remove_duplicate=''):
+    result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename, utils.get_rank()))
+    final_result_file = os.path.join(result_dir, '%s.json'%filename)
+
+    json.dump(result, open(result_file, 'w'))
+
+    dist.barrier()
+
+    if utils.is_main_process():
+        # combine results from all processes
+        result = []
+
+        for rank in range(utils.get_world_size()):
+            result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename, rank))
+            res = json.load(open(result_file, 'r'))
+            result += res
+
+        if remove_duplicate:
+            result_new = []
+            id_list = []
+            for res in result:
+                if res[remove_duplicate] not in id_list:
+                    id_list.append(res[remove_duplicate])
+                    result_new.append(res)
+            result = result_new
+
+        json.dump(result, open(final_result_file, 'w'))
+        print('result file saved to %s'%final_result_file)
+
+    return final_result_file
+
+
+from pycocotools.coco import COCO
+from pycocoevalcap.eval import COCOEvalCap
+from torchvision.datasets.utils import download_url
+
+def coco_caption_eval(coco_gt_root, results_file, split):
+    urls = {'val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
+            'test': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
+    filenames = {'val': 'coco_karpathy_val_gt.json', 'test': 'coco_karpathy_test_gt.json'}
+
+    download_url(urls[split], coco_gt_root)
+    annotation_file = os.path.join(coco_gt_root, filenames[split])
+
+    # create coco object and coco_result object
+    coco = COCO(annotation_file)
+    coco_result = coco.loadRes(results_file)
+
+    # create coco_eval object by taking coco and coco_result
+    coco_eval = COCOEvalCap(coco, coco_result)
+
+    # to evaluate on only the subset of images present in the results, uncomment:
+    # coco_eval.params['image_id'] = coco_result.getImgIds()
+    # (leave it commented out to evaluate the full split)
+
+    # evaluate results
+    # SPICE will take a few minutes the first time, but speeds up due to caching
+    coco_eval.evaluate()
+
+    # print output evaluation scores
+    for metric, score in coco_eval.eval.items():
+        print(f'{metric}: {score:.3f}')
+
+    return coco_eval
\ No newline at end of file
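Note (not part of the patch): a worked example of pre_caption — lowercase, replace the listed punctuation with spaces, collapse whitespace runs, strip, and truncate to max_words tokens:

pre_caption('A man riding a wave!! (on a surfboard)', max_words=5)
# -> 'a man riding a wave'

pre_question differs in that it deletes the punctuation outright instead of replacing it with spaces, and it does not collapse interior whitespace.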
diff --git a/data/vqa_dataset.py b/data/vqa_dataset.py
new file mode 100644
index 0000000..92ec1df
--- /dev/null
+++ b/data/vqa_dataset.py
@@ -0,0 +1,88 @@
+import os
+import json
+import random
+from PIL import Image
+
+import torch
+from torch.utils.data import Dataset
+from data.utils import pre_question
+
+from torchvision.datasets.utils import download_url
+
+class vqa_dataset(Dataset):
+    def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
+        self.split = split
+
+        self.transform = transform
+        self.vqa_root = vqa_root
+        self.vg_root = vg_root
+
+        if split=='train':
+            urls = {'vqa_train': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
+                    'vqa_val': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
+                    'vg_qa': 'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
+
+            self.annotation = []
+            for f in train_files:
+                download_url(urls[f], ann_root)
+                self.annotation += json.load(open(os.path.join(ann_root, '%s.json'%f), 'r'))
+        else:
+            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json', ann_root)
+            self.annotation = json.load(open(os.path.join(ann_root, 'vqa_test.json'), 'r'))
+
+            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json', ann_root)
+            self.answer_list = json.load(open(os.path.join(ann_root, 'answer_list.json'), 'r'))
+
+    def __len__(self):
+        return len(self.annotation)
+
+    def __getitem__(self, index):
+
+        ann = self.annotation[index]
+
+        if ann['dataset']=='vqa':
+            image_path = os.path.join(self.vqa_root, ann['image'])
+        elif ann['dataset']=='vg':
+            image_path = os.path.join(self.vg_root, ann['image'])
+
+        image = Image.open(image_path).convert('RGB')
+        image = self.transform(image)
+
+        if self.split == 'test':
+            question = pre_question(ann['question'])
+            question_id = ann['question_id']
+            return image, question, question_id
+
+        elif self.split=='train':
+
+            question = pre_question(ann['question'])
+
+            if ann['dataset']=='vqa':
+                answer_weight = {}
+                for answer in ann['answer']:
+                    if answer in answer_weight.keys():
+                        answer_weight[answer] += 1/len(ann['answer'])
+                    else:
+                        answer_weight[answer] = 1/len(ann['answer'])
+
+                answers = list(answer_weight.keys())
+                weights = list(answer_weight.values())
+
+            elif ann['dataset']=='vg':
+                answers = [ann['answer']]
+                weights = [0.2]
+
+            return image, question, answers, weights
+
+
+def vqa_collate_fn(batch):
+    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
+    for image, question, answer, weights in batch:
+        image_list.append(image)
+        question_list.append(question)
+        weight_list += weights
+        answer_list += answer
+        n.append(len(answer))
+    return torch.stack(image_list, dim=0), question_list, answer_list, torch.Tensor(weight_list), n
\ No newline at end of file
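Note (not part of the patch): each VQA training sample carries a variable-length list of answers and weights, so the default collate cannot batch it — the train loader must use vqa_collate_fn. The returned n records how many answers belong to each question, letting the model map the flattened answer_list and weight_list back to questions. A usage sketch (the batch sizes are illustrative; the None samplers assume a single process):

from data import create_dataset, create_loader
from data.vqa_dataset import vqa_collate_fn

train_ds, test_ds = create_dataset('vqa', config)
train_loader, test_loader = create_loader(
    [train_ds, test_ds], [None, None],
    batch_size=[16, 32], num_workers=[4, 4],
    is_trains=[True, False], collate_fns=[vqa_collate_fn, None])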