mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Preperation for audio
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
|
|||||||
import random
|
import random
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
from toolkit.prompt_utils import PromptEmbeds
|
from toolkit.prompt_utils import PromptEmbeds
|
||||||
|
|
||||||
@@ -1073,6 +1074,15 @@ class GenerateImageConfig:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video format {self.output_ext}")
|
raise ValueError(f"Unsupported video format {self.output_ext}")
|
||||||
|
elif self.output_ext in ['wav', 'mp3']:
|
||||||
|
# save audio file
|
||||||
|
torchaudio.save(
|
||||||
|
self.get_image_path(count, max_count),
|
||||||
|
image[0].to('cpu'),
|
||||||
|
sample_rate=48000,
|
||||||
|
format=None,
|
||||||
|
backend=None
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# TODO save image gen header info for A1111 and us, our seeds probably wont match
|
# TODO save image gen header info for A1111 and us, our seeds probably wont match
|
||||||
image.save(self.get_image_path(count, max_count))
|
image.save(self.get_image_path(count, max_count))
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class ACTION_TYPES_SLIDER:
|
|||||||
class PromptEmbeds:
|
class PromptEmbeds:
|
||||||
# text_embeds: torch.Tensor
|
# text_embeds: torch.Tensor
|
||||||
# pooled_embeds: Union[torch.Tensor, None]
|
# pooled_embeds: Union[torch.Tensor, None]
|
||||||
# attention_mask: Union[torch.Tensor, None]
|
# attention_mask: Union[torch.Tensor, List[torch.Tensor], None]
|
||||||
|
|
||||||
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
|
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
|
||||||
if isinstance(args, list) or isinstance(args, tuple):
|
if isinstance(args, list) or isinstance(args, tuple):
|
||||||
@@ -43,7 +43,10 @@ class PromptEmbeds:
|
|||||||
if self.pooled_embeds is not None:
|
if self.pooled_embeds is not None:
|
||||||
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
|
||||||
if self.attention_mask is not None:
|
if self.attention_mask is not None:
|
||||||
self.attention_mask = self.attention_mask.to(*args, **kwargs)
|
if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
|
||||||
|
self.attention_mask = [t.to(*args, **kwargs) for t in self.attention_mask]
|
||||||
|
else:
|
||||||
|
self.attention_mask = self.attention_mask.to(*args, **kwargs)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def detach(self):
|
def detach(self):
|
||||||
@@ -55,7 +58,10 @@ class PromptEmbeds:
|
|||||||
if new_embeds.pooled_embeds is not None:
|
if new_embeds.pooled_embeds is not None:
|
||||||
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
|
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
|
||||||
if new_embeds.attention_mask is not None:
|
if new_embeds.attention_mask is not None:
|
||||||
new_embeds.attention_mask = new_embeds.attention_mask.detach()
|
if isinstance(new_embeds.attention_mask, list) or isinstance(new_embeds.attention_mask, tuple):
|
||||||
|
new_embeds.attention_mask = [t.detach() for t in new_embeds.attention_mask]
|
||||||
|
else:
|
||||||
|
new_embeds.attention_mask = new_embeds.attention_mask.detach()
|
||||||
return new_embeds
|
return new_embeds
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
@@ -69,7 +75,10 @@ class PromptEmbeds:
|
|||||||
prompt_embeds = PromptEmbeds(cloned_text_embeds)
|
prompt_embeds = PromptEmbeds(cloned_text_embeds)
|
||||||
|
|
||||||
if self.attention_mask is not None:
|
if self.attention_mask is not None:
|
||||||
prompt_embeds.attention_mask = self.attention_mask.clone()
|
if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
|
||||||
|
prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask]
|
||||||
|
else:
|
||||||
|
prompt_embeds.attention_mask = self.attention_mask.clone()
|
||||||
return prompt_embeds
|
return prompt_embeds
|
||||||
|
|
||||||
def expand_to_batch(self, batch_size):
|
def expand_to_batch(self, batch_size):
|
||||||
@@ -89,7 +98,10 @@ class PromptEmbeds:
|
|||||||
if pe.pooled_embeds is not None:
|
if pe.pooled_embeds is not None:
|
||||||
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
|
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
|
||||||
if pe.attention_mask is not None:
|
if pe.attention_mask is not None:
|
||||||
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
|
if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
|
||||||
|
pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask]
|
||||||
|
else:
|
||||||
|
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
|
||||||
return pe
|
return pe
|
||||||
|
|
||||||
def save(self, path: str):
|
def save(self, path: str):
|
||||||
@@ -108,7 +120,11 @@ class PromptEmbeds:
|
|||||||
if pe.pooled_embeds is not None:
|
if pe.pooled_embeds is not None:
|
||||||
state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
|
state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
|
||||||
if pe.attention_mask is not None:
|
if pe.attention_mask is not None:
|
||||||
state_dict["attention_mask"] = pe.attention_mask.cpu()
|
if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
|
||||||
|
for i, attn in enumerate(pe.attention_mask):
|
||||||
|
state_dict[f"attention_mask_{i}"] = attn.cpu()
|
||||||
|
else:
|
||||||
|
state_dict["attention_mask"] = pe.attention_mask.cpu()
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
save_file(state_dict, path)
|
save_file(state_dict, path)
|
||||||
|
|
||||||
@@ -122,7 +138,7 @@ class PromptEmbeds:
|
|||||||
state_dict = load_file(path, device='cpu')
|
state_dict = load_file(path, device='cpu')
|
||||||
text_embeds = []
|
text_embeds = []
|
||||||
pooled_embeds = None
|
pooled_embeds = None
|
||||||
attention_mask = None
|
attention_mask = []
|
||||||
for key in sorted(state_dict.keys()):
|
for key in sorted(state_dict.keys()):
|
||||||
if key.startswith("text_embed_"):
|
if key.startswith("text_embed_"):
|
||||||
text_embeds.append(state_dict[key])
|
text_embeds.append(state_dict[key])
|
||||||
@@ -130,19 +146,25 @@ class PromptEmbeds:
|
|||||||
text_embeds.append(state_dict[key])
|
text_embeds.append(state_dict[key])
|
||||||
elif key == "pooled_embed":
|
elif key == "pooled_embed":
|
||||||
pooled_embeds = state_dict[key]
|
pooled_embeds = state_dict[key]
|
||||||
|
elif key.startswith("attention_mask_"):
|
||||||
|
attention_mask.append(state_dict[key])
|
||||||
elif key == "attention_mask":
|
elif key == "attention_mask":
|
||||||
attention_mask = state_dict[key]
|
attention_mask.append(state_dict[key])
|
||||||
pe = cls(None)
|
pe = cls(None)
|
||||||
pe.text_embeds = text_embeds
|
pe.text_embeds = text_embeds
|
||||||
if len(text_embeds) == 1:
|
if len(text_embeds) == 1:
|
||||||
pe.text_embeds = text_embeds[0]
|
pe.text_embeds = text_embeds[0]
|
||||||
if pooled_embeds is not None:
|
if pooled_embeds is not None:
|
||||||
pe.pooled_embeds = pooled_embeds
|
pe.pooled_embeds = pooled_embeds
|
||||||
if attention_mask is not None:
|
if len(attention_mask) > 0:
|
||||||
pe.attention_mask = attention_mask
|
if len(attention_mask) == 1:
|
||||||
|
pe.attention_mask = attention_mask[0]
|
||||||
|
else:
|
||||||
|
pe.attention_mask = attention_mask
|
||||||
return pe
|
return pe
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EncodedPromptPair:
|
class EncodedPromptPair:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user