Preparation for audio

Jaret Burkett
2025-09-02 07:26:50 -06:00
parent 0f2239ca23
commit 7040d8d73b
2 changed files with 42 additions and 10 deletions


@@ -4,6 +4,7 @@ from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict
 import random
 import torch
+import torchaudio
 from toolkit.prompt_utils import PromptEmbeds
@@ -1073,6 +1074,15 @@ class GenerateImageConfig:
                )
            else:
                raise ValueError(f"Unsupported video format {self.output_ext}")
+        elif self.output_ext in ['wav', 'mp3']:
+            # save audio file
+            torchaudio.save(
+                self.get_image_path(count, max_count),
+                image[0].to('cpu'),
+                sample_rate=48000,
+                format=None,
+                backend=None
+            )
         else:
             # TODO save image gen header info for A1111 and us, our seeds probably wont match
             image.save(self.get_image_path(count, max_count))
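For readers unfamiliar with the call: the new branch writes the first tensor of the generated batch to disk as audio, reusing the image path helper for the filename and inferring the container format from the extension. A minimal standalone sketch of the same call, with a synthetic waveform standing in for image[0] (the waveform itself is invented):

    import torch
    import torchaudio

    # stand-in for image[0]: one second of stereo noise at 48 kHz,
    # shaped [channels, samples] as torchaudio.save expects
    waveform = torch.randn(2, 48000)

    torchaudio.save(
        "output.wav",        # the .wav extension selects the container when format=None
        waveform.to('cpu'),  # torchaudio.save needs a CPU tensor
        sample_rate=48000,
        format=None,         # infer the format from the file extension
        backend=None,        # let torchaudio pick an available backend
    )

Note that the sample rate is hard-coded to 48000 in the diff; audio generated at a different rate would need resampling before this call.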


@@ -21,7 +21,7 @@ class ACTION_TYPES_SLIDER:
 class PromptEmbeds:
     # text_embeds: torch.Tensor
     # pooled_embeds: Union[torch.Tensor, None]
-    # attention_mask: Union[torch.Tensor, None]
+    # attention_mask: Union[torch.Tensor, List[torch.Tensor], None]
 
     def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None:
         if isinstance(args, list) or isinstance(args, tuple):
@@ -43,7 +43,10 @@ class PromptEmbeds:
         if self.pooled_embeds is not None:
             self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs)
         if self.attention_mask is not None:
-            self.attention_mask = self.attention_mask.to(*args, **kwargs)
+            if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
+                self.attention_mask = [t.to(*args, **kwargs) for t in self.attention_mask]
+            else:
+                self.attention_mask = self.attention_mask.to(*args, **kwargs)
         return self
 
     def detach(self):
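The remaining hunks in this file all apply one recurring guard: attention_mask may now be a list of per-encoder masks rather than a single tensor, so to(), detach(), clone(), expand_to_batch(), and save() each branch on its type, as the hunks below show. A rough sketch of the new to() behavior (the construction is simplified and the mask shapes are invented):

    import torch
    from toolkit.prompt_utils import PromptEmbeds

    # a PromptEmbeds carrying one mask per text encoder
    pe = PromptEmbeds(torch.randn(1, 77, 768))
    pe.attention_mask = [torch.ones(1, 77), torch.ones(1, 32)]

    # to() now walks the list instead of assuming a single tensor
    pe.to(torch.float16)
    assert all(m.dtype == torch.float16 for m in pe.attention_mask)

    # the single-tensor path is unchanged
    pe.attention_mask = torch.ones(1, 77)
    pe.to(torch.float32)
    assert pe.attention_mask.dtype == torch.float32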
@@ -55,7 +58,10 @@ class PromptEmbeds:
         if new_embeds.pooled_embeds is not None:
             new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach()
         if new_embeds.attention_mask is not None:
-            new_embeds.attention_mask = new_embeds.attention_mask.detach()
+            if isinstance(new_embeds.attention_mask, list) or isinstance(new_embeds.attention_mask, tuple):
+                new_embeds.attention_mask = [t.detach() for t in new_embeds.attention_mask]
+            else:
+                new_embeds.attention_mask = new_embeds.attention_mask.detach()
         return new_embeds
 
     def clone(self):
@@ -69,7 +75,10 @@ class PromptEmbeds:
         prompt_embeds = PromptEmbeds(cloned_text_embeds)
         if self.attention_mask is not None:
-            prompt_embeds.attention_mask = self.attention_mask.clone()
+            if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple):
+                prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask]
+            else:
+                prompt_embeds.attention_mask = self.attention_mask.clone()
         return prompt_embeds
 
     def expand_to_batch(self, batch_size):
@@ -89,7 +98,10 @@ class PromptEmbeds:
         if pe.pooled_embeds is not None:
             pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1)
         if pe.attention_mask is not None:
-            pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
+            if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
+                pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask]
+            else:
+                pe.attention_mask = pe.attention_mask.expand(batch_size, -1)
         return pe
 
     def save(self, path: str):
@@ -108,7 +120,11 @@ class PromptEmbeds:
         if pe.pooled_embeds is not None:
             state_dict["pooled_embed"] = pe.pooled_embeds.cpu()
         if pe.attention_mask is not None:
-            state_dict["attention_mask"] = pe.attention_mask.cpu()
+            if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple):
+                for i, attn in enumerate(pe.attention_mask):
+                    state_dict[f"attention_mask_{i}"] = attn.cpu()
+            else:
+                state_dict["attention_mask"] = pe.attention_mask.cpu()
         os.makedirs(os.path.dirname(path), exist_ok=True)
         save_file(state_dict, path)
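A list-valued mask is fanned out into numbered keys, so the safetensors file stays a flat mapping of names to tensors. Roughly, the state dict written by save() for a hypothetical two-encoder prompt would contain (key names taken from the diff; the shapes are invented):

    # "text_embed_0":     float tensor [1, 77, 768]
    # "text_embed_1":     float tensor [1, 32, 2048]
    # "pooled_embed":     float tensor [1, 1280]
    # "attention_mask_0": mask tensor  [1, 77]
    # "attention_mask_1": mask tensor  [1, 32]
    #
    # a single-tensor mask keeps the old "attention_mask" key unchanged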
@@ -122,7 +138,7 @@ class PromptEmbeds:
         state_dict = load_file(path, device='cpu')
         text_embeds = []
         pooled_embeds = None
-        attention_mask = None
+        attention_mask = []
         for key in sorted(state_dict.keys()):
             if key.startswith("text_embed_"):
                 text_embeds.append(state_dict[key])
@@ -130,19 +146,25 @@ class PromptEmbeds:
                 text_embeds.append(state_dict[key])
             elif key == "pooled_embed":
                 pooled_embeds = state_dict[key]
+            elif key.startswith("attention_mask_"):
+                attention_mask.append(state_dict[key])
             elif key == "attention_mask":
-                attention_mask = state_dict[key]
+                attention_mask.append(state_dict[key])
         pe = cls(None)
         pe.text_embeds = text_embeds
         if len(text_embeds) == 1:
             pe.text_embeds = text_embeds[0]
         if pooled_embeds is not None:
             pe.pooled_embeds = pooled_embeds
-        if attention_mask is not None:
-            pe.attention_mask = attention_mask
+        if len(attention_mask) > 0:
+            if len(attention_mask) == 1:
+                pe.attention_mask = attention_mask[0]
+            else:
+                pe.attention_mask = attention_mask
         return pe
 
 class EncodedPromptPair:
     def __init__(
         self,
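Because the loader now accumulates attention_mask* keys into a list and unwraps it only when exactly one entry is found, a save/load round trip preserves whether the mask was a single tensor or a list. A sketch, assuming the classmethod shown above is exposed as PromptEmbeds.load (the path and shapes are arbitrary):

    import torch
    from toolkit.prompt_utils import PromptEmbeds

    pe = PromptEmbeds(torch.randn(1, 77, 768))
    pe.attention_mask = [torch.ones(1, 77), torch.ones(1, 32)]
    pe.save("/tmp/prompt_embeds/example.safetensors")

    loaded = PromptEmbeds.load("/tmp/prompt_embeds/example.safetensors")
    assert isinstance(loaded.attention_mask, list)    # two masks round-trip as a list
    assert loaded.attention_mask[0].shape == (1, 77)  # one mask would come back as a bare tensor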