Remove inefficient computation from AttentionPool2d Module (#271)
* fix inefficient attention computation
* remove erroneous formatting
* simplified flatten

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>
@@ -66,11 +66,11 @@ class AttentionPool2d(nn.Module):
         self.num_heads = num_heads
 
     def forward(self, x):
-        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
         x, _ = F.multi_head_attention_forward(
-            query=x, key=x, value=x,
+            query=x[:1], key=x, value=x,
             embed_dim_to_check=x.shape[-1],
             num_heads=self.num_heads,
             q_proj_weight=self.q_proj.weight,
@@ -88,8 +88,7 @@ class AttentionPool2d(nn.Module):
             training=self.training,
             need_weights=False
         )
-
-        return x[0]
+        return x.squeeze(0)
 
 
 class ModifiedResNet(nn.Module):
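The optimization is safe because each query row of multi-head attention is computed independently: the module only ever keeps the pooled token's output, so passing just that token as the query (query=x[:1]) gives the same result as attending with all HW+1 queries and discarding every row but the first. Below is a minimal standalone sketch, not from the repository, that checks this equivalence; the toy sizes, the random q/k/v/output weights standing in for the module's q_proj/k_proj/v_proj/c_proj, and the pool helper are all illustrative assumptions.

# Sketch: verify that query=x[:1] matches query=x followed by taking row 0.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 8, 4, 4   # assumed toy batch/channel/spatial sizes
num_heads = 2
x = torch.randn(N, C, H, W)

# Random weights standing in for the module's learned projections.
q_w, k_w, v_w = (torch.randn(C, C) for _ in range(3))
in_bias = torch.randn(3 * C)              # concatenated q/k/v biases
out_w, out_b = torch.randn(C, C), torch.randn(C)
pos = torch.randn(H * W + 1, C)           # positional embedding, (HW+1)C

def pool(x, query_all):
    t = x.flatten(start_dim=2).permute(2, 0, 1)             # NCHW -> (HW)NC
    t = torch.cat([t.mean(dim=0, keepdim=True), t], dim=0)  # (HW+1)NC
    t = t + pos[:, None, :].to(t.dtype)                     # (HW+1)NC
    q = t if query_all else t[:1]                           # old vs. new query
    out, _ = F.multi_head_attention_forward(
        query=q, key=t, value=t,
        embed_dim_to_check=t.shape[-1],
        num_heads=num_heads,
        q_proj_weight=q_w, k_proj_weight=k_w, v_proj_weight=v_w,
        in_proj_weight=None, in_proj_bias=in_bias,
        bias_k=None, bias_v=None, add_zero_attn=False,
        dropout_p=0.0,
        out_proj_weight=out_w, out_proj_bias=out_b,
        use_separate_proj_weight=True,
        training=False, need_weights=False,
    )
    return out[0] if query_all else out.squeeze(0)          # both (N, C)

old, new = pool(x, query_all=True), pool(x, query_all=False)
print(torch.allclose(old, new, atol=1e-6))  # expected: True

The old path ran attention for all HW+1 query positions and threw away all but one row; the new path does a single-query attention, cutting that cost by roughly a factor of HW+1. The x.flatten(start_dim=2) change is purely cosmetic: it produces the same tensor as the earlier reshape over shape[2] * shape[3], just more readably, and return x.squeeze(0) replaces x[0] because the output now has a leading dimension of exactly 1.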