diff --git a/models/blip_itm.py b/models/blip_itm.py
new file mode 100644
index 0000000..cf354c8
--- /dev/null
+++ b/models/blip_itm.py
@@ -0,0 +1,76 @@
+from models.med import BertConfig, BertModel
+from transformers import BertTokenizer
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from models.blip import create_vit, init_tokenizer, load_checkpoint
+
+class BLIP_ITM(nn.Module):
+    def __init__(self,
+                 med_config = 'configs/med_config.json',
+                 image_size = 384,
+                 vit = 'base',
+                 vit_grad_ckpt = False,
+                 vit_ckpt_layer = 0,
+                 embed_dim = 256,
+                 ):
+        """
+        Args:
+            med_config (str): path for the mixture of encoder-decoder model's configuration file
+            image_size (int): input image size
+            vit (str): model size of vision transformer
+        """
+        super().__init__()
+
+        self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
+        self.tokenizer = init_tokenizer()
+        med_config = BertConfig.from_json_file(med_config)
+        med_config.encoder_width = vision_width
+        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)
+
+        text_width = self.text_encoder.config.hidden_size
+
+        self.vision_proj = nn.Linear(vision_width, embed_dim)
+        self.text_proj = nn.Linear(text_width, embed_dim)
+
+        self.itm_head = nn.Linear(text_width, 2)
+
+
+    def forward(self, image, caption, match_head='itm'):
+
+        image_embeds = self.visual_encoder(image)
+        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
+
+        text = self.tokenizer(caption, padding='max_length', truncation=True, max_length=35,
+                              return_tensors="pt").to(image.device)
+
+
+        if match_head=='itm':
+            output = self.text_encoder(text.input_ids,
+                                       attention_mask = text.attention_mask,
+                                       encoder_hidden_states = image_embeds,
+                                       encoder_attention_mask = image_atts,
+                                       return_dict = True,
+                                      )
+            itm_output = self.itm_head(output.last_hidden_state[:,0,:])
+            return itm_output
+
+        elif match_head=='itc':
+            text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
+                                            return_dict = True, mode = 'text')
+            image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1)
+            text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:,0,:]),dim=-1)
+
+            sim = image_feat @ text_feat.t()
+            return sim
+
+
+def blip_itm(pretrained='',**kwargs):
+    model = BLIP_ITM(**kwargs)
+    if pretrained:
+        model,msg = load_checkpoint(model,pretrained)
+        assert(len(msg.missing_keys)==0)
+    return model
+
\ No newline at end of file
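
A minimal usage sketch for the `blip_itm` factory and the two matching heads added in this file. The checkpoint path, demo image, and caption below are placeholders, and the resize/normalization constants assume the CLIP-style preprocessing commonly used with BLIP; only the `blip_itm(...)` call and the `match_head='itm'` / `match_head='itc'` forward modes come from the diff itself.

```python
# Hypothetical scoring script; paths and preprocessing stats are assumptions,
# not part of the diff above.
import torch
from PIL import Image
from torchvision import transforms

from models.blip_itm import blip_itm

image_size = 384
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size),
                      interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    # CLIP-style mean/std, assumed here for illustration
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = blip_itm(pretrained='path/to/blip_itm_checkpoint.pth',  # placeholder path
                 image_size=image_size, vit='base')
model.eval().to(device)

image = preprocess(Image.open('demo.jpg').convert('RGB')).unsqueeze(0).to(device)
caption = 'a dog playing in the park'

with torch.no_grad():
    # ITM head: 2-way logits over (no-match, match); softmax column 1 is the match probability
    itm_logits = model(image, caption, match_head='itm')
    itm_score = torch.softmax(itm_logits, dim=1)[:, 1]

    # ITC head: cosine similarity between the projected image and text [CLS] features
    itc_score = model(image, caption, match_head='itc')

print(f'ITM match probability: {itm_score.item():.4f}')
print(f'ITC similarity:        {itc_score.item():.4f}')
```

The ITC head is cheap (one dot product between pooled features) and suits large-scale retrieval, while the ITM head runs cross-attention over all image tokens and is typically used to rerank the top ITC candidates.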