From f69a9bc217f6df9213628848b3f9b0b6fc542401 Mon Sep 17 00:00:00 2001
From: Penn
Date: Thu, 21 Jul 2022 13:04:35 -0700
Subject: [PATCH] Remove inefficient computation from `AttentionPool2d` Module
 (#271)

* fix inefficient attention computation

* remove erroneous formatting

* simplified flatten

Co-authored-by: Jong Wook Kim
---
 clip/model.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/clip/model.py b/clip/model.py
index 3121dd7..808bf16 100644
--- a/clip/model.py
+++ b/clip/model.py
@@ -66,11 +66,11 @@ class AttentionPool2d(nn.Module):
         self.num_heads = num_heads
 
     def forward(self, x):
-        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
+        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
         x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
         x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
         x, _ = F.multi_head_attention_forward(
-            query=x, key=x, value=x,
+            query=x[:1], key=x, value=x,
             embed_dim_to_check=x.shape[-1],
             num_heads=self.num_heads,
             q_proj_weight=self.q_proj.weight,
@@ -88,8 +88,7 @@ class AttentionPool2d(nn.Module):
             training=self.training,
             need_weights=False
         )
-
-        return x[0]
+        return x.squeeze(0)
 
 
 class ModifiedResNet(nn.Module):
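
Note (editorial, not part of the patch): the original code ran multi-head attention with all HW+1 tokens as queries and then kept only the output at index 0, so the work done for the other HW query positions was discarded. Since each query row attends over the keys independently (no attention mask, zero dropout), passing only the mean-pooled token as `query=x[:1]` produces the same pooled output while skipping the unused rows; `return x.squeeze(0)` then removes the leading length-1 query dimension that `return x[0]` used to index away. Below is a minimal sketch of that equivalence. The helper name `pooled`, the toy dimensions, and the standalone projection layers are illustrative stand-ins for the corresponding `AttentionPool2d` attributes, not code from the repository.

import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Toy sizes for illustration only; the real model uses larger values.
N, C, H, W = 2, 64, 7, 7
num_heads = 8

# Projections mirroring AttentionPool2d's q/k/v/c layers and positional embedding.
q_proj = nn.Linear(C, C)
k_proj = nn.Linear(C, C)
v_proj = nn.Linear(C, C)
c_proj = nn.Linear(C, C)
positional_embedding = nn.Parameter(torch.randn(H * W + 1, C) / C ** 0.5)

feat = torch.randn(N, C, H, W)

def pooled(query_slice):
    # Same steps as AttentionPool2d.forward, with the query restricted by `query_slice`.
    x = feat.flatten(start_dim=2).permute(2, 0, 1)           # NCHW -> (HW)NC
    x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)   # (HW+1)NC
    x = x + positional_embedding[:, None, :].to(x.dtype)     # (HW+1)NC
    out, _ = F.multi_head_attention_forward(
        query=x[query_slice], key=x, value=x,
        embed_dim_to_check=x.shape[-1],
        num_heads=num_heads,
        q_proj_weight=q_proj.weight,
        k_proj_weight=k_proj.weight,
        v_proj_weight=v_proj.weight,
        in_proj_weight=None,
        in_proj_bias=torch.cat([q_proj.bias, k_proj.bias, v_proj.bias]),
        bias_k=None,
        bias_v=None,
        add_zero_attn=False,
        dropout_p=0,
        out_proj_weight=c_proj.weight,
        out_proj_bias=c_proj.bias,
        use_separate_proj_weight=True,
        training=False,
        need_weights=False
    )
    return out

old = pooled(slice(None))[0]          # attend with all HW+1 queries, keep token 0
new = pooled(slice(0, 1)).squeeze(0)  # attend with only the first query
print(torch.allclose(old, new, atol=1e-6))  # expected: True

The two calls differ only in how many query rows are projected and attended, which is why the cheaper version returns the same pooled feature.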