Files
stable-diffusion-webui-forge/backend/nn/clip.py
2024-08-05 03:08:17 -07:00

13 lines
442 B
Python

import torch
from transformers import CLIPTextModel, CLIPTextConfig
class IntegratedCLIP(torch.nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.transformer = CLIPTextModel(config)
embed_dim = config.hidden_size
self.transformer.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
self.transformer.text_projection.weight.copy_(torch.eye(embed_dim))