Update blip_pretrain.py

This commit is contained in:
Junnan Li
2022-02-15 11:20:04 +08:00
committed by GitHub
parent ad5eec314c
commit 073b821aa2

View File

@@ -91,7 +91,7 @@ class BLIP_Pretrain(nn.Module):
decoder_config.encoder_width = vision_width
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
tie_encoder_decoder_weights(self.text_decoder.bert,self.text_encoder,'','/attention')
tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
def forward(self, image, caption, alpha):