Mirror of https://github.com/openai/CLIP.git, synced 2026-01-26 15:29:48 +00:00
Patch clip model for ONNX compatibility (#219)
* Patch clip model for ONNX compatibility

  Change tokenization to use INT32, since ONNX does not yet support ArgMax(INT64).
  Use an explicit dimension for norm.

* Add compatibility fix for torch 1.7
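For context, a minimal sketch (not part of this commit) of the export path the patch unblocks. encode_text() selects the EOT embedding with text.argmax(dim=-1), and ONNX's ArgMax operator rejects INT64 input, so the traced graph needs int32 tokens. The model name and opset below are illustrative; whether the full model exports cleanly still depends on the torch and opset versions used:

    import torch
    import clip

    # jit=False so we export the plain nn.Module rather than a TorchScript archive
    model, _ = clip.load("ViT-B/32", device="cpu", jit=False)
    tokens = clip.tokenize(["a photo of a cat"])   # int32 on torch >= 1.8.0 after this patch
    image = torch.randn(1, 3, 224, 224)

    torch.onnx.export(
        model, (image, tokens), "clip.onnx",
        input_names=["image", "text"],
        output_names=["logits_per_image", "logits_per_text"],
        opset_version=14,                          # assumption: any opset with ArgMax support
    )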
clip/clip.py (10 changed lines):
@@ -192,7 +192,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
     return model, _transform(model.input_resolution.item())
 
 
-def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
+def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
     """
     Returns the tokenized representation of given input string(s)
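With the widened return annotation, callers can no longer assume int64 tokens. A quick check of what the patched tokenize returns (a sketch assuming the patched package is installed):

    import clip

    tokens = clip.tokenize(["a diagram", "a dog"])
    print(tokens.shape)   # torch.Size([2, 77])
    print(tokens.dtype)   # torch.int32 on torch >= 1.8.0, torch.int64 below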
@@ -209,7 +209,8 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
 
     Returns
     -------
-    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
     """
     if isinstance(texts, str):
         texts = [texts]
@@ -217,7 +218,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
     sot_token = _tokenizer.encoder["<|startoftext|>"]
     eot_token = _tokenizer.encoder["<|endoftext|>"]
     all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
-    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    else:
+        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
 
     for i, tokens in enumerate(all_tokens):
         if len(tokens) > context_length:
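The hunk assumes packaging is already importable in clip.py (upstream pulls it in via `from pkg_resources import packaging`, which is why `packaging.version.parse` resolves without a new import). A standalone equivalent of the version gate:

    from packaging import version
    import torch

    # int64 for old torch, where index_select requires long indices;
    # int32 otherwise, so the ONNX ArgMax export works
    if version.parse(torch.__version__) < version.parse("1.8.0"):
        dtype = torch.long
    else:
        dtype = torch.int

    result = torch.zeros(3, 77, dtype=dtype)   # e.g. 3 prompts, context_length=77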
clip/model.py:

@@ -356,8 +356,8 @@ class CLIP(nn.Module):
         text_features = self.encode_text(text)
 
         # normalized features
-        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
-        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+        image_features = image_features / image_features.norm(dim=1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=1, keepdim=True)
 
         # cosine similarity as logits
         logit_scale = self.logit_scale.exp()
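For the 2-D [batch, embed_dim] feature tensors here, dim=1 and dim=-1 address the same axis, so the change is behavior-preserving; the explicit index only spares the ONNX exporter from resolving a negative axis. A quick sanity check:

    import torch

    feats = torch.randn(4, 512)                  # [batch, embed_dim]
    a = feats / feats.norm(dim=-1, keepdim=True)
    b = feats / feats.norm(dim=1, keepdim=True)
    assert torch.allclose(a, b)                  # same axis for 2-D input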