Preparation for audio

This commit is contained in:
Jaret Burkett
2025-09-02 07:26:50 -06:00
parent 0f2239ca23
commit 7040d8d73b
2 changed files with 42 additions and 10 deletions

View File

@@ -4,6 +4,7 @@ from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
import random import random
import torch import torch
import torchaudio
from toolkit.prompt_utils import PromptEmbeds from toolkit.prompt_utils import PromptEmbeds
@@ -1073,6 +1074,15 @@ class GenerateImageConfig:
) )
else: else:
raise ValueError(f"Unsupported video format {self.output_ext}") raise ValueError(f"Unsupported video format {self.output_ext}")
elif self.output_ext in ['wav', 'mp3']:
# save audio file
torchaudio.save(
self.get_image_path(count, max_count),
image[0].to('cpu'),
sample_rate=48000,
format=None,
backend=None
)
else: else:
# TODO save image gen header info for A1111 and us, our seeds probably wont match # TODO save image gen header info for A1111 and us, our seeds probably wont match
image.save(self.get_image_path(count, max_count)) image.save(self.get_image_path(count, max_count))

View File

@@ -21,7 +21,7 @@ class ACTION_TYPES_SLIDER:
class PromptEmbeds: class PromptEmbeds:
# text_embeds: torch.Tensor # text_embeds: torch.Tensor
# pooled_embeds: Union[torch.Tensor, None] # pooled_embeds: Union[torch.Tensor, None]
# attention_mask: Union[torch.Tensor, None] # attention_mask: Union[torch.Tensor, List[torch.Tensor], None]
def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None: def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
if isinstance(args, list) or isinstance(args, tuple): if isinstance(args, list) or isinstance(args, tuple):
@@ -43,7 +43,10 @@ class PromptEmbeds:
if self.pooled_embeds is not None: if self.pooled_embeds is not None:
self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
if self.attention_mask is not None: if self.attention_mask is not None:
self.attention_mask = self.attention_mask.to(*args, **kwargs) if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
self.attention_mask = [t.to(*args, **kwargs) for t in self.attention_mask]
else:
self.attention_mask = self.attention_mask.to(*args, **kwargs)
return self return self
def detach(self): def detach(self):
@@ -55,7 +58,10 @@ class PromptEmbeds:
if new_embeds.pooled_embeds is not None: if new_embeds.pooled_embeds is not None:
new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
if new_embeds.attention_mask is not None: if new_embeds.attention_mask is not None:
new_embeds.attention_mask = new_embeds.attention_mask.detach() if isinstance(new_embeds.attention_mask, list) or isinstance(new_embeds.attention_mask, tuple):
new_embeds.attention_mask = [t.detach() for t in new_embeds.attention_mask]
else:
new_embeds.attention_mask = new_embeds.attention_mask.detach()
return new_embeds return new_embeds
def clone(self): def clone(self):
@@ -69,7 +75,10 @@ class PromptEmbeds:
prompt_embeds = PromptEmbeds(cloned_text_embeds) prompt_embeds = PromptEmbeds(cloned_text_embeds)
if self.attention_mask is not None: if self.attention_mask is not None:
prompt_embeds.attention_mask = self.attention_mask.clone() if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask]
else:
prompt_embeds.attention_mask = self.attention_mask.clone()
return prompt_embeds return prompt_embeds
def expand_to_batch(self, batch_size): def expand_to_batch(self, batch_size):
@@ -89,7 +98,10 @@ class PromptEmbeds:
if pe.pooled_embeds is not None: if pe.pooled_embeds is not None:
pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1) pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
if pe.attention_mask is not None: if pe.attention_mask is not None:
pe.attention_mask = pe.attention_mask.expand(batch_size, -1) if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask]
else:
pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
return pe return pe
def save(self, path: str): def save(self, path: str):
@@ -108,7 +120,11 @@ class PromptEmbeds:
if pe.pooled_embeds is not None: if pe.pooled_embeds is not None:
state_dict["pooled_embed"] = pe.pooled_embeds.cpu() state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
if pe.attention_mask is not None: if pe.attention_mask is not None:
state_dict["attention_mask"] = pe.attention_mask.cpu() if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
for i, attn in enumerate(pe.attention_mask):
state_dict[f"attention_mask_{i}"] = attn.cpu()
else:
state_dict["attention_mask"] = pe.attention_mask.cpu()
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
save_file(state_dict, path) save_file(state_dict, path)
@@ -122,7 +138,7 @@ class PromptEmbeds:
state_dict = load_file(path, device='cpu') state_dict = load_file(path, device='cpu')
text_embeds = [] text_embeds = []
pooled_embeds = None pooled_embeds = None
attention_mask = None attention_mask = []
for key in sorted(state_dict.keys()): for key in sorted(state_dict.keys()):
if key.startswith("text_embed_"): if key.startswith("text_embed_"):
text_embeds.append(state_dict[key]) text_embeds.append(state_dict[key])
@@ -130,19 +146,25 @@ class PromptEmbeds:
text_embeds.append(state_dict[key]) text_embeds.append(state_dict[key])
elif key == "pooled_embed": elif key == "pooled_embed":
pooled_embeds = state_dict[key] pooled_embeds = state_dict[key]
elif key.startswith("attention_mask_"):
attention_mask.append(state_dict[key])
elif key == "attention_mask": elif key == "attention_mask":
attention_mask = state_dict[key] attention_mask.append(state_dict[key])
pe = cls(None) pe = cls(None)
pe.text_embeds = text_embeds pe.text_embeds = text_embeds
if len(text_embeds) == 1: if len(text_embeds) == 1:
pe.text_embeds = text_embeds[0] pe.text_embeds = text_embeds[0]
if pooled_embeds is not None: if pooled_embeds is not None:
pe.pooled_embeds = pooled_embeds pe.pooled_embeds = pooled_embeds
if attention_mask is not None: if len(attention_mask) > 0:
pe.attention_mask = attention_mask if len(attention_mask) == 1:
pe.attention_mask = attention_mask[0]
else:
pe.attention_mask = attention_mask
return pe return pe
class EncodedPromptPair: class EncodedPromptPair:
def __init__( def __init__(
self, self,