mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-04-27 01:39:20 +00:00)
Added training for pixart-a
This commit extends PromptEmbeds with an optional attention_mask that is carried through to(), detach(), and clone():
@@ -19,10 +19,11 @@ class ACTION_TYPES_SLIDER:
 class PromptEmbeds:
-    text_embeds: torch.Tensor
-    pooled_embeds: Union[torch.Tensor, None]
+    # text_embeds: torch.Tensor
+    # pooled_embeds: Union[torch.Tensor, None]
+    # attention_mask: Union[torch.Tensor, None]
 
-    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor]) -> None:
+    def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
         if isinstance(args, list) or isinstance(args, tuple):
             # xl
             self.text_embeds = args[0]
@@ -32,10 +33,14 @@ class PromptEmbeds:
             self.text_embeds = args
             self.pooled_embeds = None
+
+        self.attention_mask = attention_mask
 
     def to(self, *args, **kwargs):
         self.text_embeds = self.text_embeds.to(*args, **kwargs)
         if self.pooled_embeds is not None:
             self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
+        if self.attention_mask is not None:
+            self.attention_mask = self.attention_mask.to(*args, **kwargs)
         return self
 
     def detach(self):
@@ -43,13 +48,19 @@ class PromptEmbeds:
         new_embeds.text_embeds = new_embeds.text_embeds.detach()
         if new_embeds.pooled_embeds is not None:
             new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
+        if new_embeds.attention_mask is not None:
+            new_embeds.attention_mask = new_embeds.attention_mask.detach()
         return new_embeds
 
     def clone(self):
         if self.pooled_embeds is not None:
-            return PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
+            prompt_embeds = PromptEmbeds([self.text_embeds.clone(), self.pooled_embeds.clone()])
         else:
-            return PromptEmbeds(self.text_embeds.clone())
+            prompt_embeds = PromptEmbeds(self.text_embeds.clone())
+
+        if self.attention_mask is not None:
+            prompt_embeds.attention_mask = self.attention_mask.clone()
+        return prompt_embeds
 
 
 class EncodedPromptPair:
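
For context, here is a minimal sketch of how the extended class behaves after this change. The PromptEmbeds API it uses (the constructor, to(), detach(), clone(), and the attention_mask attribute) is taken from the diff above; the module path and tensor shapes are illustrative assumptions, not part of the commit.

import torch

from toolkit.prompt_utils import PromptEmbeds  # assumed module path, not shown in the diff

# Placeholder tensors standing in for a text encoder's output; the shapes
# here are illustrative only.
text_embeds = torch.randn(1, 120, 4096)                # (batch, seq_len, hidden_dim)
attention_mask = torch.ones(1, 120, dtype=torch.long)  # 1 = real token, 0 = padding

# A single tensor argument takes the non-XL branch, so pooled_embeds stays None.
pe = PromptEmbeds(text_embeds, attention_mask=attention_mask)

# After this commit the mask follows the embeddings through device moves,
# detach, and clone instead of being dropped.
device = "cuda" if torch.cuda.is_available() else "cpu"
pe = pe.to(device, dtype=torch.float32)
frozen = pe.detach()   # mask detached alongside the embeddings
copy = pe.clone()      # mask cloned too, thanks to the clone() change above
assert copy.attention_mask is not None

The clone() rewrite matters because the old version returned early from either branch, so a cloned PromptEmbeds could never carry the attention mask.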