Intergrate CLIP

This commit is contained in:
layerdiffusion
2024-08-01 12:24:35 -07:00
parent af0b04cc16
commit 4d1be42975
8 changed files with 172 additions and 44 deletions

11
backend/nn/clip.py Normal file
View File

@@ -0,0 +1,11 @@
import torch
from transformers import CLIPTextModel, CLIPTextConfig
class IntegratedCLIP(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.transformer = CLIPTextModel(config)
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))