Patch clip model for ONNX compatibility (#219)

* Patch clip model for ONNX compatibility

Changes to use INT32 for tokenization, since ONNX doesn't yet support ArgMax(INT64)
Use explicit dimension for norm

* Add compatibility fix for torch 1.7
This commit is contained in:
In-Ho Yi
2022-04-10 16:35:32 -04:00
committed by GitHub
parent 40f5484c1c
commit 7ef63f265b
2 changed files with 9 additions and 5 deletions

View File

@@ -192,7 +192,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
return model, _transform(model.input_resolution.item())
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)
@@ -209,7 +209,8 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
Returns
-------
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
"""
if isinstance(texts, str):
texts = [texts]
@@ -217,7 +218,10 @@ def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: b
sot_token = _tokenizer.encoder["<|startoftext|>"]
eot_token = _tokenizer.encoder["<|endoftext|>"]
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
else:
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:

View File

@@ -356,8 +356,8 @@ class CLIP(nn.Module):
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()