mirror of
https://github.com/salesforce/BLIP.git
synced 2026-04-30 20:31:12 +00:00
Update blip_pretrain.py
This commit is contained in:
@@ -91,7 +91,7 @@ class BLIP_Pretrain(nn.Module):
|
|||||||
decoder_config.encoder_width = vision_width
|
decoder_config.encoder_width = vision_width
|
||||||
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
|
self.text_decoder = BertLMHeadModel.from_pretrained('bert-base-uncased',config=decoder_config)
|
||||||
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
|
self.text_decoder.resize_token_embeddings(len(self.tokenizer))
|
||||||
tie_encoder_decoder_weights(self.text_decoder.bert,self.text_encoder,'','/attention')
|
tie_encoder_decoder_weights(self.text_encoder,self.text_decoder.bert,'','/attention')
|
||||||
|
|
||||||
|
|
||||||
def forward(self, image, caption, alpha):
|
def forward(self, image, caption, alpha):
|
||||||
|
|||||||
Reference in New Issue
Block a user