mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import torch
|
|
|
|
|
|
class Decorator(torch.nn.Module):
    """Learnable tokens that are appended to a batch of text embeddings.

    The module owns a single ``(num_tokens, token_size)`` parameter. On every
    forward pass that parameter is broadcast across the batch and concatenated
    to the incoming embeddings along the sequence (second to last) dimension.
    """

    def __init__(
        self,
        num_tokens: int = 4,
        token_size: int = 4096,
    ) -> None:
        """Create the randomly initialised decorator tokens.

        Args:
            num_tokens: Number of tokens appended to each sequence.
            token_size: Size of the last dimension of each token; it must
                match the last dimension of the embeddings given to
                ``forward`` for the concatenation to succeed.
        """
        super().__init__()

        self.weight: torch.nn.Parameter = torch.nn.Parameter(
            torch.randn(num_tokens, token_size)
        )
        # ensure it is float32 (torch.randn follows the global default dtype)
        self.weight.data = self.weight.data.float()

    def forward(
        self,
        text_embeds: torch.Tensor,
        is_unconditional: bool = False,
    ) -> torch.Tensor:
        """Append the decorator tokens to ``text_embeds``.

        Args:
            text_embeds: Tensor of shape ``(batch, seq_len, token_size)``.
            is_unconditional: When true, zeros are appended instead of the
                learned tokens, so the sequence length still grows by
                ``num_tokens``.

        Returns:
            Tensor of shape ``(batch, seq_len + num_tokens, token_size)`` with
            the same dtype as ``text_embeds``.
        """
        # Make sure the param is float32. The check is against float32 itself
        # rather than the input dtype: after the module has been cast to a
        # low precision dtype (e.g. ``module.to(torch.bfloat16)``) an input of
        # that same dtype would otherwise leave the weight in low precision.
        if self.weight.dtype != torch.float32:
            self.weight.data = self.weight.data.float()

        # expand batch to match text_embeds (a view, no copy of the weight)
        batch_size = text_embeds.shape[0]
        decorator_embeds = self.weight.unsqueeze(0).expand(batch_size, -1, -1)

        if is_unconditional:
            # zero pad the decorator embeds
            decorator_embeds = torch.zeros_like(decorator_embeds)

        # torch.cat needs matching dtypes; only the appended copy is cast,
        # the parameter itself stays float32.
        if decorator_embeds.dtype != text_embeds.dtype:
            decorator_embeds = decorator_embeds.to(text_embeds.dtype)

        text_embeds = torch.cat((text_embeds, decorator_embeds), dim=-2)

        return text_embeds
|