mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Preparation for audio
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from toolkit.prompt_utils import PromptEmbeds
|
||||
|
||||
@@ -1073,6 +1074,15 @@ class GenerateImageConfig:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video format {self.output_ext}")
|
||||
elif self.output_ext in ['wav', 'mp3']:
|
||||
# save audio file
|
||||
torchaudio.save(
|
||||
self.get_image_path(count, max_count),
|
||||
image[0].to('cpu'),
|
||||
sample_rate=48000,
|
||||
format=None,
|
||||
backend=None
|
||||
)
|
||||
else:
|
||||
# TODO save image gen header info for A1111 and us, our seeds probably wont match
|
||||
image.save(self.get_image_path(count, max_count))
|
||||
|
||||
@@ -21,7 +21,7 @@ class ACTION_TYPES_SLIDER:
|
||||
class PromptEmbeds:
|
||||
# text_embeds: torch.Tensor
|
||||
# pooled_embeds: Union[torch.Tensor, None]
|
||||
# attention_mask: Union[torch.Tensor, None]
|
||||
# attention_mask: Union[torch.Tensor, List[torch.Tensor], None]
|
||||
|
||||
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
|
||||
if isinstance(args, list) or isinstance(args, tuple):
|
||||
@@ -43,7 +43,10 @@ class PromptEmbeds:
|
||||
if self.pooled_embeds is not None:
|
||||
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
||||
if self.attention_mask is not None:
|
||||
self.attention_mask = self.attention_mask.to(*args, **kwargs)
|
||||
if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
|
||||
self.attention_mask = [t.to(*args, **kwargs) for t in self.attention_mask]
|
||||
else:
|
||||
self.attention_mask = self.attention_mask.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
@@ -55,7 +58,10 @@ class PromptEmbeds:
|
||||
if new_embeds.pooled_embeds is not None:
|
||||
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
|
||||
if new_embeds.attention_mask is not None:
|
||||
new_embeds.attention_mask = new_embeds.attention_mask.detach()
|
||||
if isinstance(new_embeds.attention_mask, list) or isinstance(new_embeds.attention_mask, tuple):
|
||||
new_embeds.attention_mask = [t.detach() for t in new_embeds.attention_mask]
|
||||
else:
|
||||
new_embeds.attention_mask = new_embeds.attention_mask.detach()
|
||||
return new_embeds
|
||||
|
||||
def clone(self):
|
||||
@@ -69,7 +75,10 @@ class PromptEmbeds:
|
||||
prompt_embeds = PromptEmbeds(cloned_text_embeds)
|
||||
|
||||
if self.attention_mask is not None:
|
||||
prompt_embeds.attention_mask = self.attention_mask.clone()
|
||||
if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
|
||||
prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask]
|
||||
else:
|
||||
prompt_embeds.attention_mask = self.attention_mask.clone()
|
||||
return prompt_embeds
|
||||
|
||||
def expand_to_batch(self, batch_size):
|
||||
@@ -89,7 +98,10 @@ class PromptEmbeds:
|
||||
if pe.pooled_embeds is not None:
|
||||
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
|
||||
if pe.attention_mask is not None:
|
||||
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
|
||||
if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
|
||||
pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask]
|
||||
else:
|
||||
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
|
||||
return pe
|
||||
|
||||
def save(self, path: str):
|
||||
@@ -108,7 +120,11 @@ class PromptEmbeds:
|
||||
if pe.pooled_embeds is not None:
|
||||
state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
|
||||
if pe.attention_mask is not None:
|
||||
state_dict["attention_mask"] = pe.attention_mask.cpu()
|
||||
if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
|
||||
for i, attn in enumerate(pe.attention_mask):
|
||||
state_dict[f"attention_mask_{i}"] = attn.cpu()
|
||||
else:
|
||||
state_dict["attention_mask"] = pe.attention_mask.cpu()
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
save_file(state_dict, path)
|
||||
|
||||
@@ -122,7 +138,7 @@ class PromptEmbeds:
|
||||
state_dict = load_file(path, device='cpu')
|
||||
text_embeds = []
|
||||
pooled_embeds = None
|
||||
attention_mask = None
|
||||
attention_mask = []
|
||||
for key in sorted(state_dict.keys()):
|
||||
if key.startswith("text_embed_"):
|
||||
text_embeds.append(state_dict[key])
|
||||
@@ -130,19 +146,25 @@ class PromptEmbeds:
|
||||
text_embeds.append(state_dict[key])
|
||||
elif key == "pooled_embed":
|
||||
pooled_embeds = state_dict[key]
|
||||
elif key.startswith("attention_mask_"):
|
||||
attention_mask.append(state_dict[key])
|
||||
elif key == "attention_mask":
|
||||
attention_mask = state_dict[key]
|
||||
attention_mask.append(state_dict[key])
|
||||
pe = cls(None)
|
||||
pe.text_embeds = text_embeds
|
||||
if len(text_embeds) == 1:
|
||||
pe.text_embeds = text_embeds[0]
|
||||
if pooled_embeds is not None:
|
||||
pe.pooled_embeds = pooled_embeds
|
||||
if attention_mask is not None:
|
||||
pe.attention_mask = attention_mask
|
||||
if len(attention_mask) > 0:
|
||||
if len(attention_mask) == 1:
|
||||
pe.attention_mask = attention_mask[0]
|
||||
else:
|
||||
pe.attention_mask = attention_mask
|
||||
return pe
|
||||
|
||||
|
||||
|
||||
class EncodedPromptPair:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user