From 7040d8d73b46df091fc326a064db7605d3b7bd6f Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Tue, 2 Sep 2025 07:26:50 -0600 Subject: [PATCH] Preperation for audio --- toolkit/config_modules.py | 10 ++++++++++ toolkit/prompt_utils.py | 42 +++++++++++++++++++++++++++++---------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index b6c0a3d7..ae082c80 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -4,6 +4,7 @@ from typing import List, Optional, Literal, Tuple, Union, TYPE_CHECKING, Dict import random import torch +import torchaudio from toolkit.prompt_utils import PromptEmbeds @@ -1073,6 +1074,15 @@ class GenerateImageConfig: ) else: raise ValueError(f"Unsupported video format {self.output_ext}") + elif self.output_ext in ['wav', 'mp3']: + # save audio file + torchaudio.save( + self.get_image_path(count, max_count), + image[0].to('cpu'), + sample_rate=48000, + format=None, + backend=None + ) else: # TODO save image gen header info for A1111 and us, our seeds probably wont match image.save(self.get_image_path(count, max_count)) diff --git a/toolkit/prompt_utils.py b/toolkit/prompt_utils.py index b7e27cf9..b8a6f1e5 100644 --- a/toolkit/prompt_utils.py +++ b/toolkit/prompt_utils.py @@ -21,7 +21,7 @@ class ACTION_TYPES_SLIDER: class PromptEmbeds: # text_embeds: torch.Tensor # pooled_embeds: Union[torch.Tensor, None] - # attention_mask: Union[torch.Tensor, None] + # attention_mask: Union[torch.Tensor, List[torch.Tensor], None] def __init__(self, args: Union[Tuple[torch.Tensor], List[torch.Tensor], torch.Tensor], attention_mask=None) -> None: if isinstance(args, list) or isinstance(args, tuple): @@ -43,7 +43,10 @@ class PromptEmbeds: if self.pooled_embeds is not None: self.pooled_embeds = self.pooled_embeds.to(*args, **kwargs) if self.attention_mask is not None: - self.attention_mask = self.attention_mask.to(*args, **kwargs) + if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple): + self.attention_mask = [t.to(*args, **kwargs) for t in self.attention_mask] + else: + self.attention_mask = self.attention_mask.to(*args, **kwargs) return self def detach(self): @@ -55,7 +58,10 @@ class PromptEmbeds: if new_embeds.pooled_embeds is not None: new_embeds.pooled_embeds = new_embeds.pooled_embeds.detach() if new_embeds.attention_mask is not None: - new_embeds.attention_mask = new_embeds.attention_mask.detach() + if isinstance(new_embeds.attention_mask, list) or isinstance(new_embeds.attention_mask, tuple): + new_embeds.attention_mask = [t.detach() for t in new_embeds.attention_mask] + else: + new_embeds.attention_mask = new_embeds.attention_mask.detach() return new_embeds def clone(self): @@ -69,7 +75,10 @@ class PromptEmbeds: prompt_embeds = PromptEmbeds(cloned_text_embeds) if self.attention_mask is not None: - prompt_embeds.attention_mask = self.attention_mask.clone() + if isinstance(self.attention_mask, list) or isinstance(self.attention_mask, tuple): + prompt_embeds.attention_mask = [t.clone() for t in self.attention_mask] + else: + prompt_embeds.attention_mask = self.attention_mask.clone() return prompt_embeds def expand_to_batch(self, batch_size): @@ -89,7 +98,10 @@ class PromptEmbeds: if pe.pooled_embeds is not None: pe.pooled_embeds = pe.pooled_embeds.expand(batch_size, -1) if pe.attention_mask is not None: - pe.attention_mask = pe.attention_mask.expand(batch_size, -1) + if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple): + pe.attention_mask = [t.expand(batch_size, -1) for t in pe.attention_mask] + else: + pe.attention_mask = pe.attention_mask.expand(batch_size, -1) return pe def save(self, path: str): @@ -108,7 +120,11 @@ class PromptEmbeds: if pe.pooled_embeds is not None: state_dict["pooled_embed"] = pe.pooled_embeds.cpu() if pe.attention_mask is not None: - state_dict["attention_mask"] = pe.attention_mask.cpu() + if isinstance(pe.attention_mask, list) or isinstance(pe.attention_mask, tuple): + for i, attn in enumerate(pe.attention_mask): + state_dict[f"attention_mask_{i}"] = attn.cpu() + else: + state_dict["attention_mask"] = pe.attention_mask.cpu() os.makedirs(os.path.dirname(path), exist_ok=True) save_file(state_dict, path) @@ -122,7 +138,7 @@ class PromptEmbeds: state_dict = load_file(path, device='cpu') text_embeds = [] pooled_embeds = None - attention_mask = None + attention_mask = [] for key in sorted(state_dict.keys()): if key.startswith("text_embed_"): text_embeds.append(state_dict[key]) @@ -130,19 +146,25 @@ class PromptEmbeds: text_embeds.append(state_dict[key]) elif key == "pooled_embed": pooled_embeds = state_dict[key] + elif key.startswith("attention_mask_"): + attention_mask.append(state_dict[key]) elif key == "attention_mask": - attention_mask = state_dict[key] + attention_mask.append(state_dict[key]) pe = cls(None) pe.text_embeds = text_embeds if len(text_embeds) == 1: pe.text_embeds = text_embeds[0] if pooled_embeds is not None: pe.pooled_embeds = pooled_embeds - if attention_mask is not None: - pe.attention_mask = attention_mask + if len(attention_mask) > 0: + if len(attention_mask) == 1: + pe.attention_mask = attention_mask[0] + else: + pe.attention_mask = attention_mask return pe + class EncodedPromptPair: def __init__( self,