Added training for PixArt-α (PixArt-Alpha)

This commit is contained in:
Jaret Burkett
2024-02-13 16:00:04 -07:00
parent 4ec4025cbb
commit 93b52932c1
10 changed files with 288 additions and 24 deletions

View File

@@ -19,10 +19,11 @@ class ACTION_TYPES_SLIDER:
class PromptEmbeds:
text_embeds: torch.Tensor
pooled_embeds: Union[torch.Tensor, None]
# text_embeds: torch.Tensor
# pooled_embeds: Union[torch.Tensor, None]
# attention_mask: Union[torch.Tensor, None]
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
if isinstance(args, list) or isinstance(args, tuple):
# xl
self.text_embeds = args[0]
@@ -32,10 +33,14 @@ class PromptEmbeds:
self.text_embeds = args
self.pooled_embeds = None
self.attention_mask = attention_mask
def to(self, *args, **kwargs):
    """Move or cast every stored tensor in place (mirrors ``torch.Tensor.to``).

    Applies ``.to(*args, **kwargs)`` to ``text_embeds`` unconditionally and to
    ``pooled_embeds`` / ``attention_mask`` only when they are not ``None``.
    Returns ``self`` so calls can be chained.
    """
    self.text_embeds = self.text_embeds.to(*args, **kwargs)
    # Optional tensors: only convert the ones that are actually present.
    for name in ("pooled_embeds", "attention_mask"):
        value = getattr(self, name)
        if value is not None:
            setattr(self, name, value.to(*args, **kwargs))
    return self
def detach(self):
@@ -43,13 +48,19 @@ class PromptEmbeds:
new_embeds.text_embeds = new_embeds.text_embeds.detach()
if new_embeds.pooled_embeds is not None:
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
if new_embeds.attention_mask is not None:
new_embeds.attention_mask = new_embeds.attention_mask.detach()
return new_embeds
def clone(self):
# NOTE(review): this span is diff residue — the commit-page rendering stripped
# the +/- markers, so the pre-commit and post-commit versions of clone() are
# interleaved below. As plain Python the assignments after each `return` would
# be unreachable; in the real file only the "added" lines remain.
if self.pooled_embeds is not None:
# removed line (pre-commit): returned immediately, so attention_mask was never copied
return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
# added line (post-commit): build the copy first so attention_mask can be attached below
prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
else:
# removed line (pre-commit)
return PromptEmbeds(self.text_embeds.clone())
# added line (post-commit)
prompt_embeds = PromptEmbeds(self.text_embeds.clone())
# added by this commit: deep-copy the attention mask onto the clone as well
if self.attention_mask is not None:
prompt_embeds.attention_mask = self.attention_mask.clone()
return prompt_embeds
class EncodedPromptPair: