From 4056f54f8671b9369c7295a9ccc74b3ea7cf5555 Mon Sep 17 00:00:00 2001 From: WildAi <2853742+wildminder@users.noreply.github.com> Date: Wed, 27 Aug 2025 15:51:44 +0300 Subject: [PATCH] init --- .gitignore | 163 +++ vibevoice/__init__.py | 0 vibevoice/configs/qwen2.5_1.5b_64k.json | 112 ++ vibevoice/configs/qwen2.5_7b_32k.json | 113 ++ vibevoice/modular/__init__.py | 0 vibevoice/modular/configuration_vibevoice.py | 248 ++++ vibevoice/modular/modeling_vibevoice.py | 488 +++++++ .../modular/modeling_vibevoice_inference.py | 715 ++++++++++ .../modular_vibevoice_diffusion_head.py | 287 ++++ .../modular_vibevoice_text_tokenizer.py | 214 +++ .../modular/modular_vibevoice_tokenizer.py | 1195 +++++++++++++++++ vibevoice/modular/streamer.py | 264 ++++ vibevoice/processor/__init__.py | 0 vibevoice/processor/vibevoice_processor.py | 677 ++++++++++ .../vibevoice_tokenizer_processor.py | 483 +++++++ vibevoice/schedule/__init__.py | 0 vibevoice/schedule/dpm_solver.py | 1065 +++++++++++++++ vibevoice/schedule/timestep_sampler.py | 19 + vibevoice/scripts/__init__.py | 0 ...ert_nnscaler_checkpoint_to_transformers.py | 166 +++ 20 files changed, 6209 insertions(+) create mode 100644 .gitignore create mode 100644 vibevoice/__init__.py create mode 100644 vibevoice/configs/qwen2.5_1.5b_64k.json create mode 100644 vibevoice/configs/qwen2.5_7b_32k.json create mode 100644 vibevoice/modular/__init__.py create mode 100644 vibevoice/modular/configuration_vibevoice.py create mode 100644 vibevoice/modular/modeling_vibevoice.py create mode 100644 vibevoice/modular/modeling_vibevoice_inference.py create mode 100644 vibevoice/modular/modular_vibevoice_diffusion_head.py create mode 100644 vibevoice/modular/modular_vibevoice_text_tokenizer.py create mode 100644 vibevoice/modular/modular_vibevoice_tokenizer.py create mode 100644 vibevoice/modular/streamer.py create mode 100644 vibevoice/processor/__init__.py create mode 100644 vibevoice/processor/vibevoice_processor.py create mode 100644 vibevoice/processor/vibevoice_tokenizer_processor.py create mode 100644 vibevoice/schedule/__init__.py create mode 100644 vibevoice/schedule/dpm_solver.py create mode 100644 vibevoice/schedule/timestep_sampler.py create mode 100644 vibevoice/scripts/__init__.py create mode 100644 vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..78da158 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.github +.idea +.Python +__pycache__ +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/vibevoice/__init__.py b/vibevoice/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vibevoice/configs/qwen2.5_1.5b_64k.json b/vibevoice/configs/qwen2.5_1.5b_64k.json new file mode 100644 index 0000000..febd05c --- /dev/null +++ b/vibevoice/configs/qwen2.5_1.5b_64k.json @@ -0,0 +1,112 @@ +{ + "_attn_implementation_autoset": true, + "acoustic_vae_dim": 64, + "acoustic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "decoder_depths": null, + "decoder_n_filters": 32, + "decoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0.5, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibepod_acoustic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "gaussian", + "vae_dim": 64, + "weight_init_value": 0.01 + }, + "decoder_config": { + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 1536, + "initializer_range": 0.02, + "intermediate_size": 8960, + "max_position_embeddings": 65536, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 12, + "num_hidden_layers": 28, + "num_key_value_heads": 2, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": true, + "torch_dtype": "bfloat16", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 + }, + "diffusion_head_config": { + "ddpm_batch_mul": 4, + "ddpm_beta_schedule": "cosine", + "ddpm_num_inference_steps": 20, + "ddpm_num_steps": 1000, + "diffusion_type": "ddpm", + "head_ffn_ratio": 3.0, + "head_layers": 4, + "hidden_size": 1536, + "latent_size": 64, + "model_type": "vibepod_diffusion_head", + "prediction_type": "v_prediction", + "rms_norm_eps": 1e-05, + "speech_vae_dim": 64 + }, + "model_type": "vibepod", + "semantic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibepod_semantic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "none", + "vae_dim": 128, + "weight_init_value": 0.01 + }, + "semantic_vae_dim": 128, + "torch_dtype": "bfloat16" +} diff --git a/vibevoice/configs/qwen2.5_7b_32k.json b/vibevoice/configs/qwen2.5_7b_32k.json new file mode 100644 index 0000000..d39952c --- /dev/null +++ b/vibevoice/configs/qwen2.5_7b_32k.json @@ -0,0 +1,113 @@ +{ + "_attn_implementation_autoset": true, + "acoustic_vae_dim": 64, + "acoustic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "decoder_depths": null, + "decoder_n_filters": 32, + "decoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0.5, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + 
"layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibepod_acoustic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "gaussian", + "vae_dim": 64, + "weight_init_value": 0.01 + }, + "decoder_config": { + "attention_dropout": 0.0, + "hidden_act": "silu", + "hidden_size": 3584, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.1", + "use_cache": true, + "use_mrope": false, + "use_sliding_window": false, + "vocab_size": 152064 + }, + "diffusion_head_config": { + "ddpm_batch_mul": 4, + "ddpm_beta_schedule": "cosine", + "ddpm_num_inference_steps": 20, + "ddpm_num_steps": 1000, + "diffusion_type": "ddpm", + "head_ffn_ratio": 3.0, + "head_layers": 4, + "hidden_size": 3584, + "latent_size": 64, + "model_type": "vibepod_diffusion_head", + "prediction_type": "v_prediction", + "rms_norm_eps": 1e-05, + "speech_vae_dim": 64 + }, + "model_type": "vibepod", + "semantic_tokenizer_config": { + "causal": true, + "channels": 1, + "conv_bias": true, + "conv_norm": "none", + "corpus_normalize": 0.0, + "disable_last_norm": true, + "encoder_depths": "3-3-3-3-3-3-8", + "encoder_n_filters": 32, + "encoder_ratios": [ + 8, + 5, + 5, + 4, + 2, + 2 + ], + "fix_std": 0, + "layer_scale_init_value": 1e-06, + "layernorm": "RMSNorm", + "layernorm_elementwise_affine": true, + "layernorm_eps": 1e-05, + "mixer_layer": "depthwise_conv", + "model_type": "vibepod_semantic_tokenizer", + "pad_mode": "constant", + "std_dist_type": "none", + "vae_dim": 128, + "weight_init_value": 0.01 + }, + "semantic_vae_dim": 128, + "torch_dtype": "bfloat16" +} diff --git a/vibevoice/modular/__init__.py b/vibevoice/modular/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vibevoice/modular/configuration_vibevoice.py b/vibevoice/modular/configuration_vibevoice.py new file mode 100644 index 0000000..fcffcb9 --- /dev/null +++ b/vibevoice/modular/configuration_vibevoice.py @@ -0,0 +1,248 @@ +""" VibeVoice_AcousticTokenizer model configuration""" + +from typing import Dict, List, Optional, Tuple + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config + +logger = logging.get_logger(__name__) + + +class VibeVoiceAcousticTokenizerConfig(PretrainedConfig): + model_type = "vibevoice_acoustic_tokenizer" + + def __init__( + self, + channels: int = 1, + corpus_normalize: float = 0.0, + causal: bool = True, + vae_dim: int = 64, + fix_std: float = 0.5, + std_dist_type: str = 'gaussian', + # common + mixer_layer: str = 'depthwise_conv', + conv_norm: str = 'none', + pad_mode: str = 'constant', + disable_last_norm: bool = True, + layernorm: str = 'RMSNorm', + layernorm_eps: float = 1e-5, + layernorm_elementwise_affine: bool = True, + conv_bias: bool = True, + layer_scale_init_value: float = 1e-6, + weight_init_value: float = 1e-2, + # encoder specific + encoder_n_filters: int = 32, + encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2], + encoder_depths: str = "3-3-3-3-3-3-8", + # decoder specific + decoder_n_filters: int = 32, + decoder_ratios: Optional[List[int]] = None, # if None, same as encoder + 
decoder_depths: Optional[str] = None, + **kwargs + ): + super().__init__(**kwargs) + self.channels = channels + self.corpus_normalize = corpus_normalize + self.causal = causal + self.vae_dim = vae_dim + self.fix_std = fix_std + self.std_dist_type = std_dist_type + + # common parameters + self.conv_norm = conv_norm + self.pad_mode = pad_mode + self.layernorm_eps = layernorm_eps + self.disable_last_norm = disable_last_norm + self.layernorm = layernorm + self.layernorm_elementwise_affine = layernorm_elementwise_affine + self.conv_bias = conv_bias + self.layer_scale_init_value = layer_scale_init_value + self.weight_init_value = weight_init_value + self.mixer_layer = mixer_layer + + # encoder specific parameters + self.encoder_n_filters = encoder_n_filters + self.encoder_ratios = encoder_ratios + self.encoder_depths = encoder_depths + + # decoder specific parameters + self.decoder_ratios = decoder_ratios if decoder_ratios is not None else encoder_ratios + self.decoder_n_filters = decoder_n_filters + self.decoder_depths = decoder_depths + + +class VibeVoiceSemanticTokenizerConfig(PretrainedConfig): + model_type = "vibevoice_semantic_tokenizer" + + def __init__( + self, + channels: int = 1, + corpus_normalize: float = 0.0, + causal: bool = True, + vae_dim: int = 64, + fix_std: float = 0, + std_dist_type: str = 'none', + # common + mixer_layer: str = 'depthwise_conv', + conv_norm: str = 'none', + pad_mode: str = 'constant', + disable_last_norm: bool = True, + layernorm: str = 'RMSNorm', + layernorm_eps: float = 1e-5, + layernorm_elementwise_affine: bool = True, + conv_bias: bool = True, + layer_scale_init_value: float = 1e-6, + weight_init_value: float = 1e-2, + # encoder specific + encoder_n_filters: int = 32, + encoder_ratios: Optional[List[int]] = [8,5,5,4,2,2], + encoder_depths: str = "3-3-3-3-3-3-8", + **kwargs + ): + super().__init__(**kwargs) + self.channels = channels + self.corpus_normalize = corpus_normalize + self.causal = causal + self.vae_dim = vae_dim + self.fix_std = fix_std + self.std_dist_type = std_dist_type + + # common parameters + self.conv_norm = conv_norm + self.pad_mode = pad_mode + self.layernorm_eps = layernorm_eps + self.disable_last_norm = disable_last_norm + self.layernorm = layernorm + self.layernorm_elementwise_affine = layernorm_elementwise_affine + self.conv_bias = conv_bias + self.layer_scale_init_value = layer_scale_init_value + self.weight_init_value = weight_init_value + self.mixer_layer = mixer_layer + + # encoder specific parameters + self.encoder_n_filters = encoder_n_filters + self.encoder_ratios = encoder_ratios + self.encoder_depths = encoder_depths + + +class VibeVoiceDiffusionHeadConfig(PretrainedConfig): + model_type = "vibevoice_diffusion_head" + + def __init__( + self, + hidden_size=768, + head_layers=4, + head_ffn_ratio=3.0, + rms_norm_eps=1e-5, + latent_size=64, + speech_vae_dim=None, + prediction_type="v_prediction", + diffusion_type="ddpm", + ddpm_num_steps=1000, + ddpm_num_inference_steps=20, + ddpm_beta_schedule="cosine", + ddpm_batch_mul=4, + **kwargs + ): + self.hidden_size = hidden_size + self.head_layers = head_layers + self.head_ffn_ratio = head_ffn_ratio + self.rms_norm_eps = rms_norm_eps + self.latent_size = latent_size + self.speech_vae_dim = speech_vae_dim + self.prediction_type = prediction_type + self.diffusion_type = diffusion_type + self.ddpm_num_steps = ddpm_num_steps + self.ddpm_num_inference_steps = ddpm_num_inference_steps + self.ddpm_beta_schedule = ddpm_beta_schedule + self.ddpm_batch_mul = ddpm_batch_mul + + 
super().__init__(**kwargs) + +class VibeVoiceConfig(PretrainedConfig): + model_type = "vibevoice" + is_composition = True + sub_configs = { + "acoustic_tokenizer_config": VibeVoiceAcousticTokenizerConfig, + "semantic_tokenizer_config": VibeVoiceSemanticTokenizerConfig, + "decoder_config": Qwen2Config, + "diffusion_head_config": VibeVoiceDiffusionHeadConfig, + } + # keys_to_ignore_at_inference = ["past_key_values"] + # Default tensor parallel plan for base model `Qwen2` + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + def __init__( + self, + acoustic_tokenizer_config=None, + semantic_tokenizer_config=None, + decoder_config=None, + diffusion_head_config=None, + **kwargs + ): + + # kwargs["_attn_implementation"] = "flash_attention_2" + kwargs["_attn_implementation_autoset"] = False + + if acoustic_tokenizer_config is None: + self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"]() + elif isinstance(acoustic_tokenizer_config, dict): + acoustic_tokenizer_config["model_type"] = "vibevoice_acoustic_tokenizer" + self.acoustic_tokenizer_config = self.sub_configs["acoustic_tokenizer_config"](**acoustic_tokenizer_config) + elif isinstance(acoustic_tokenizer_config, VibeVoiceAcousticTokenizerConfig): + # If an instance of the config class is provided + self.acoustic_tokenizer_config = acoustic_tokenizer_config + + if semantic_tokenizer_config is None: + self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"]() + elif isinstance(semantic_tokenizer_config, dict): + semantic_tokenizer_config["model_type"] = "vibevoice_semantic_tokenizer" + self.semantic_tokenizer_config = self.sub_configs["semantic_tokenizer_config"](**semantic_tokenizer_config) + elif isinstance(semantic_tokenizer_config, VibeVoiceSemanticTokenizerConfig): + # If an instance of the config class is provided + self.semantic_tokenizer_config = semantic_tokenizer_config + + if decoder_config is None: + self.decoder_config = self.sub_configs["decoder_config"]() + elif isinstance(decoder_config, dict): + # If a dictionary is provided, instantiate the config class with it + # self.decoder_config = self.sub_configs["decoder_config"](**decoder_config) + if decoder_config.get("model_type", '') == "qwen2": + self.decoder_config = Qwen2Config(**decoder_config) + else: + raise ValueError(f"Unsupported decoder model type: {decoder_config.get('model_type', '')}") + elif isinstance(decoder_config, (Qwen2Config,)): + # If an instance of the config class is provided + self.decoder_config = decoder_config + + if diffusion_head_config is None: + self.diffusion_head_config = self.sub_configs["diffusion_head_config"]() + elif isinstance(diffusion_head_config, dict): + diffusion_head_config["model_type"] = "vibevoice_diffusion_head" + self.diffusion_head_config = self.sub_configs["diffusion_head_config"](**diffusion_head_config) + elif isinstance(diffusion_head_config, VibeVoiceDiffusionHeadConfig): + # If an instance of the config class is provided + self.diffusion_head_config = diffusion_head_config + + # other parameters + self.acoustic_vae_dim = getattr(self.acoustic_tokenizer_config, 'vae_dim', 64) + self.semantic_vae_dim = getattr(self.semantic_tokenizer_config, 'vae_dim', 128) + + super().__init__(**kwargs) + +__all__ = [ + 
"VibeVoiceAcousticTokenizerConfig", + "VibeVoiceSemanticTokenizerConfig", + "VibeVoiceDiffusionHeadConfig", + "VibeVoiceConfig" +] \ No newline at end of file diff --git a/vibevoice/modular/modeling_vibevoice.py b/vibevoice/modular/modeling_vibevoice.py new file mode 100644 index 0000000..016a389 --- /dev/null +++ b/vibevoice/modular/modeling_vibevoice.py @@ -0,0 +1,488 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union, Callable +from tqdm import tqdm +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from transformers.models.auto import AutoModel, AutoModelForCausalLM + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import CausalLMOutput, BaseModelOutputWithPast, ModelOutput +from transformers.models.llama.modeling_llama import LlamaRMSNorm +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.utils import logging + + +from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel +from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead +from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler + +from .configuration_vibevoice import VibeVoiceConfig + + +logger = logging.get_logger(__name__) + +if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] + +@dataclass +class VibeVoiceCausalLMOutputWithPast(ModelOutput): + loss: Optional[torch.FloatTensor] = None + diffusion_loss: Optional[torch.FloatTensor] = None + speech_token_num: Optional[int] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + attentions: Optional[Tuple[torch.FloatTensor, ...]] = None + + +@dataclass +class VibeVoiceGenerationOutput(ModelOutput): + """ + Output type for VibeVoice generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. + speech_outputs (`List[torch.FloatTensor]`, *optional*): + List of generated speech waveforms or latents for each speech segment. 
+ """ + sequences: torch.LongTensor = None + speech_outputs: Optional[List[torch.FloatTensor]] = None + + +class SpeechConnector(nn.Module): + def __init__(self, input_dim, output_dim): + super().__init__() + self.fc1 = nn.Linear(input_dim, output_dim) + self.norm = LlamaRMSNorm(output_dim, eps=1e-6) + self.fc2 = nn.Linear(output_dim, output_dim) + + def forward(self, features, **kwargs): + x = self.fc1(features) + x = self.norm(x) + x = self.fc2(x) + return x + + +# @auto_docstring +class VibeVoicePreTrainedModel(PreTrainedModel): + config_class = VibeVoiceConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module): + if isinstance(module, VibeVoiceDiffusionHead): + module.initialize_weights() + return + + # Use the language model's initializer_range if available + if hasattr(self.config, 'language_model_config') and hasattr(self.config.language_model_config, 'initializer_range'): + std = self.config.language_model_config.initializer_range + elif hasattr(self.config, 'decoder_config') and hasattr(self.config.decoder_config, 'initializer_range'): + std = self.config.decoder_config.initializer_range + else: + std = 0.02 # Default value + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + +# @auto_docstring +class VibeVoiceModel(VibeVoicePreTrainedModel): + def __init__(self, config): + super().__init__(config) + + if hasattr(config, 'torch_dtype') and config.torch_dtype is not None: + if isinstance(config.torch_dtype, str): + dtype = getattr(torch, config.torch_dtype) + else: + dtype = config.torch_dtype + else: + dtype = torch.float32 + + # Initialize Qwen2 model for language modeling + lm_config = config.decoder_config + self.language_model = AutoModel.from_config(lm_config) + + # Initialize speech components if needed + self.acoustic_tokenizer = AutoModel.from_config(config.acoustic_tokenizer_config).to(dtype) + self.semantic_tokenizer = AutoModel.from_config(config.semantic_tokenizer_config).to(dtype) + + self.acoustic_connector = SpeechConnector(config.acoustic_vae_dim, lm_config.hidden_size).to(dtype) + self.semantic_connector = SpeechConnector(config.semantic_vae_dim, lm_config.hidden_size).to(dtype) + + # Register scaling factors as buffers - use 1D tensors for FSDP compatibility + self.register_buffer('speech_scaling_factor', torch.tensor(float('nan'))) + self.register_buffer('speech_bias_factor', torch.tensor(float('nan'))) + + # Initialize prediction head for speech generation + self.prediction_head = AutoModel.from_config(config.diffusion_head_config).to(dtype) + + # Initialize noise scheduler + self.noise_scheduler = DPMSolverMultistepScheduler( + num_train_timesteps=config.diffusion_head_config.ddpm_num_steps, + beta_schedule=config.diffusion_head_config.ddpm_beta_schedule, + prediction_type=config.diffusion_head_config.prediction_type + ) + + def get_input_embeddings(self): + if hasattr(self.language_model, 'embed_tokens'): + # If the language model has an embed_tokens attribute, return it + return self.language_model.embed_tokens + + for name, attr in self.language_model.fullmap.items(): # 
parallel by nnscaler, the name is changed + if attr.orig_name == 'embed_tokens.weight': + return getattr(self.language_model, name) + assert False, 'should not arrive here' + + def set_input_embeddings(self, value): + self.language_model.embed_tokens = value + + def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): + """Set the speech tokenizers used for encoding and decoding speech.""" + self.acoustic_tokenizer = acoustic_tokenizer + self.semantic_tokenizer = semantic_tokenizer + + # Reset the encoder to evaluation mode + if self.acoustic_tokenizer is not None: + self.acoustic_tokenizer.eval() + + if self.semantic_tokenizer is not None: + self.semantic_tokenizer.eval() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Forward through language model + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + if not return_dict: + return outputs + + return BaseModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class VibeVoiceForConditionalGeneration(VibeVoicePreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = VibeVoiceModel(config) + self.vocab_size = config.decoder_config.vocab_size + self.lm_head = nn.Linear(config.decoder_config.hidden_size, self.vocab_size, bias=False) + + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_decoder(self, decoder): + self.model.language_model = decoder + + def get_decoder(self): + return self.model.language_model + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + """ + if getattr(self.config.decoder_config, 'tie_word_embeddings', False): + # The standard PreTrainedModel method will handle the tying. + # It typically does a simple parameter object assignment, which is + # CORRECT to do BEFORE FSDP wraps the model. 
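+            # The assignment below shares the embedding matrix with lm_head; any lm_head bias is then padded so its length matches the tied weight's vocab dimension.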
+ output_embeddings = self.get_output_embeddings() + input_embeddings = self.get_input_embeddings() + if hasattr(input_embeddings, 'weight'): + output_embeddings.weight = input_embeddings.weight + else: + # maybe returned input_embeddings a tensor directly + output_embeddings.weight = input_embeddings + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = nn.functional.pad( + output_embeddings.bias.data, + (0, output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0]), + "constant", + 0, + ) + print("✅ Tied input and output embeddings using standard assignment.") + else: + print("ℹ️ tie_word_embeddings is False, not tying weights.") + + # Also, ensure set_output_embeddings is safe, though your implementation looks okay. + # The key is to avoid calling it after accelerator.prepare(). + def set_output_embeddings(self, new_embeddings): + # Your current implementation using data.copy_ is good practice, + # but the best way is to not call this after prepare(). + self.lm_head = new_embeddings + + def forward_speech_features( + self, + speech_tensors=None, + speech_masks=None, + speech_type="audio", + return_unmask=False + ): + if speech_tensors is None: + # Use config to get vae_dim instead of non-existent self.args + vae_dim = self.config.acoustic_tokenizer_config.vae_dim + audio_features = torch.zeros(1, 1, vae_dim).to(self.get_input_embeddings().weight) + connect_features = self.model.acoustic_connector(audio_features) + return audio_features, connect_features + else: + with torch.no_grad(): + if speech_type == "audio": + with torch.no_grad(): + frames = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1))[0][0] + audio_tokens = frames.sample(self.model.acoustic_tokenizer.std_dist_type)[0] + + elif speech_type == "vae": + # Use config to get vae_dim instead of non-existent self.args + vae_dim = self.config.acoustic_tokenizer_config.vae_dim + speech_mode = speech_tensors.reshape(speech_tensors.size(0), -1, vae_dim) + + # gaussian sample from the speech_mode + batch_size = speech_mode.size(0) + value = self.model.acoustic_tokenizer.fix_std / 0.8 + std = torch.randn(batch_size, dtype=speech_mode.dtype, device=speech_mode.device) * value + std = std.view(-1, *[1] * (speech_mode.dim() - 1)) + audio_tokens = speech_mode + std * torch.randn(speech_mode.shape).to(speech_mode) + else: + raise NotImplementedError(f"Speech type {speech_type} not implemented") + + if torch.isnan(self.model.speech_scaling_factor) or torch.isnan(self.model.speech_bias_factor): + scaling_factor = 1. 
/ audio_tokens[speech_masks].flatten().std() + bias_factor = -audio_tokens[speech_masks].flatten().mean() + + # Only use distributed operations if the process group is initialized + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(scaling_factor, op=dist.ReduceOp.SUM) + dist.all_reduce(bias_factor, op=dist.ReduceOp.SUM) + world_size = dist.get_world_size() + self.model.speech_scaling_factor.copy_(scaling_factor / world_size) + self.model.speech_bias_factor.copy_(bias_factor / world_size) + print(f"Speech scaling factor (distributed): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True) + else: + # Single process case + self.model.speech_scaling_factor.copy_(scaling_factor) + self.model.speech_bias_factor.copy_(bias_factor) + print(f"Speech scaling factor (single process): {self.model.speech_scaling_factor}, bias factor: {self.model.speech_bias_factor}", flush=True) + + audio_features = (audio_tokens + self.model.speech_bias_factor) * self.model.speech_scaling_factor + + connect_features = self.model.acoustic_connector(audio_features) + if return_unmask: + return audio_features, connect_features + return audio_features[speech_masks], connect_features[speech_masks] + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + # New arguments for speech processing and loss calculation + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speeches_loss_input: Optional[torch.FloatTensor] = None, + speech_semantic_tensors: Optional[torch.FloatTensor] = None, + acoustic_input_mask: Optional[torch.BoolTensor] = None, + acoustic_loss_mask: Optional[torch.BoolTensor] = None, + ddpm_batch_mul: int = 1, + **kwargs: Optional[Dict[str, Union[torch.Tensor, str]]], + ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + x = self.get_input_embeddings()(input_ids) + + semantic_speech_all_connect_features = self.model.semantic_connector(speech_semantic_tensors) + if speeches_loss_input is not None: + # only part audio need diffuse + speech_all_features, speech_all_connect_features = self.forward_speech_features( + speech_tensors=speech_tensors.type_as(x) if speech_tensors is not None else None, + speech_masks=speech_masks, + speech_type=kwargs.get("speech_type", "audio"), + return_unmask=True + ) + if speech_tensors is not None: + if semantic_speech_all_connect_features is not None: + x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + semantic_speech_all_connect_features[speech_masks] + else: + x[acoustic_input_mask] = speech_all_connect_features[speech_masks] + speech_features = speech_all_features[speeches_loss_input.unsqueeze(-1) & speech_masks] # only part audio need diffuse + speech_connect_features = speech_all_connect_features[speeches_loss_input.unsqueeze(-1) & speech_masks] + else: + speech_features, speech_connect_features = self.forward_speech_features( + speech_tensors=speech_tensors.type_as(x) if 
speech_tensors is not None else None, + speech_masks=speech_masks, + speech_type=kwargs.get("speech_type", "audio"), + ) + if speech_tensors is not None: + x[acoustic_input_mask] = speech_connect_features + + outputs = self.model( + input_ids=None, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=x, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=False, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + # logits = logits.float() + + loss = None + if labels is not None: + # The custom CE loss with masking is calculated in the training script. + # We leave the standard loss calculation here as None. + pass + + # --- Diffusion Loss Calculation --- + diffusion_loss = None + # This block is executed only if we are in a context that involves speech. + if speech_tensors is not None and acoustic_loss_mask.sum().item() > 0: + condition_features = hidden_states[acoustic_loss_mask] + + speech_len, latent_size = speech_features.shape + + noise = torch.randn( + (speech_len * ddpm_batch_mul, latent_size), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + + timesteps = torch.multinomial( + torch.ones(self.config.diffusion_head_config.ddpm_num_steps), + speech_len * ddpm_batch_mul, + replacement=True, + ).to(hidden_states.device) + + speech_features_repeated = speech_features.repeat_interleave(ddpm_batch_mul, dim=0) + condition_features_repeated = condition_features.repeat_interleave(ddpm_batch_mul, dim=0) + + noisy_speech_features = self.model.noise_scheduler.add_noise( + speech_features_repeated, noise, timesteps + ) + + model_output = self.model.prediction_head( + noisy_speech_features, + timesteps.type_as(x), + condition_features_repeated + ) + + prediction_type = self.config.diffusion_head_config.prediction_type + if prediction_type == "epsilon": + target_for_loss = noise + elif prediction_type == "v_prediction": + target_for_loss = self.model.noise_scheduler.get_velocity( + speech_features_repeated, noise, timesteps + ) + else: + raise NotImplementedError(f"Prediction type {prediction_type} not implemented") + + diffusion_loss = F.mse_loss(model_output.float(), target_for_loss.float(), reduction='sum') + if latent_size > 0 and ddpm_batch_mul > 0: + diffusion_loss = diffusion_loss / latent_size / ddpm_batch_mul + else: + diffusion_loss = torch.tensor(0.0, device=diffusion_loss.device) + + else: + # Dummy loss for DDP to work when there are no speech samples in a batch, + # but we are in a speech context. 
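+            # The zero-weighted parameter sums below keep the prediction head and both connectors in the autograd graph, so DDP gradient synchronization stays consistent across ranks even when a batch carries no speech.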
+ diffusion_loss = sum(p.sum() for p in self.model.prediction_head.parameters()) * 0.0 + diffusion_loss += sum(p.sum() for p in self.model.acoustic_connector.parameters()) * 0.0 + diffusion_loss += sum(p.sum() for p in self.model.semantic_connector.parameters()) * 0.0 + # --- End Diffusion Loss Calculation --- + + if not return_dict: + output = (logits, speech_len) + outputs.to_tuple()[1:] + return (loss, diffusion_loss) + output + + return VibeVoiceCausalLMOutputWithPast( + loss=loss, + diffusion_loss=diffusion_loss, + speech_token_num=speech_len if speech_tensors is not None else 0, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +AutoModel.register(VibeVoiceConfig, VibeVoiceModel) +AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGeneration) + +__all__ = [ + "VibeVoiceModel", + "VibeVoicePreTrainedModel", + "VibeVoiceForConditionalGeneration", + "VibeVoiceCausalLMOutputWithPast", + "VibeVoiceGenerationOutput", +] \ No newline at end of file diff --git a/vibevoice/modular/modeling_vibevoice_inference.py b/vibevoice/modular/modeling_vibevoice_inference.py new file mode 100644 index 0000000..7e10af4 --- /dev/null +++ b/vibevoice/modular/modeling_vibevoice_inference.py @@ -0,0 +1,715 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union, Callable +from tqdm import tqdm +import torch +import torch.nn as nn + +from transformers.models.auto import AutoModel, AutoModelForCausalLM + +from transformers.generation import GenerationMixin, GenerationConfig, LogitsProcessor, LogitsProcessorList, StoppingCriteriaList +from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput +from transformers import modeling_utils +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.utils import logging + + +# from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceAcousticTokenizerModel, VibeVoiceSemanticTokenizerModel +from .modular_vibevoice_tokenizer import VibeVoiceTokenizerStreamingCache, VibeVoiceTokenizerEncoderOutput +from .modular_vibevoice_diffusion_head import VibeVoiceDiffusionHead +from vibevoice.schedule.dpm_solver import DPMSolverMultistepScheduler + +from .configuration_vibevoice import VibeVoiceConfig + +from .modular_vibevoice_text_tokenizer import VibeVoiceTextTokenizer, VibeVoiceTextTokenizerFast + +from .modeling_vibevoice import VibeVoiceModel, VibeVoicePreTrainedModel +from .streamer import AudioStreamer, AsyncAudioStreamer + +logger = logging.get_logger(__name__) + +if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None: + modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"] + +@dataclass +class VibeVoiceCausalLMOutputWithPast(BaseModelOutputWithPast): + logits: Optional[torch.FloatTensor] = None + +@dataclass +class VibeVoiceGenerationOutput(ModelOutput): + """ + Output type for VibeVoice generation. + + Args: + sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The generated sequences. + speech_outputs (`List[torch.FloatTensor]`, *optional*): + List of generated speech waveforms or latents for each speech segment. 
+ """ + sequences: torch.LongTensor = None + speech_outputs: Optional[List[torch.FloatTensor]] = None + reach_max_step_sample: Optional[torch.BoolTensor] = None + +class VibeVoiceTokenConstraintProcessor(LogitsProcessor): + """Constrains token generation to only valid tokens during speech generation.""" + + def __init__(self, valid_token_ids: List[int], device: torch.device = None): + self.valid_token_ids = torch.tensor(valid_token_ids, dtype=torch.long, device=device) + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + # Create a mask for valid tokens + mask = torch.full_like(scores, float('-inf')) + mask[:, self.valid_token_ids] = 0 + + # Apply mask to scores + scores = scores + mask + return scores + +class VibeVoiceForConditionalGenerationInference(VibeVoicePreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + + # Initialize the base model + self.model = VibeVoiceModel(config) + + # LM head for text generation + self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.decoder_config.vocab_size, bias=False) + + # inference configuration + self.ddpm_inference_steps = config.diffusion_head_config.ddpm_num_inference_steps + + # Initialize weights and apply final processing + self.post_init() + + @property + def noise_scheduler(self): + return self.model.noise_scheduler + + @property + def prediction_head(self): + return self.model.prediction_head + + @property + def speech_scaling_factor(self): + return self.model.speech_scaling_factor + + @property + def speech_bias_factor(self): + return self.model.speech_bias_factor + + @property + def acoustic_tokenizer(self): + return self.model.acoustic_tokenizer + + @property + def semantic_tokenizer(self): + return self.model.semantic_tokenizer + + @property + def acoustic_connector(self): + return self.model.acoustic_connector + + @property + def semantic_connector(self): + return self.model.semantic_connector + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. 
+ """ + # Tie lm_head.weight to language_model.embed_tokens.weight + if not getattr(self.config, 'tie_word_embeddings', False): + return + + if hasattr(self, 'lm_head') and hasattr(self.model.language_model, 'embed_tokens'): + self.lm_head.weight = self.model.language_model.embed_tokens.weight + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_speech_tokenizers(self, acoustic_tokenizer=None, semantic_tokenizer=None): + """Set the speech tokenizers used for encoding and decoding speech.""" + self.model.set_speech_tokenizers(acoustic_tokenizer, semantic_tokenizer) + + def set_ddpm_inference_steps(self, num_steps=None): + self.ddpm_inference_steps = num_steps or self.config.diffusion_head_config.ddpm_num_inference_steps + + def _process_speech_inputs(self, speech_tensors, speech_masks, speech_type="audio"): + """Process speech inputs through tokenizers and connectors.""" + with torch.no_grad(): + if speech_type == "audio": + # Encode audio to acoustic latents + encoder_output = self.model.acoustic_tokenizer.encode(speech_tensors.unsqueeze(1)) + acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0] + + # Apply scaling and bias + acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device) + + # Connect to language model space + acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()] + + return acoustic_features, acoustic_connected + elif speech_type == "pt": + encoder_output = VibeVoiceTokenizerEncoderOutput(mean=speech_tensors, std=self.acoustic_tokenizer.config.fix_std) + acoustic_latents = encoder_output.sample(dist_type=self.model.acoustic_tokenizer.std_dist_type)[0] + + # Apply scaling and bias + acoustic_features = (acoustic_latents + self.model.speech_bias_factor.to(acoustic_latents.device)) * self.model.speech_scaling_factor.to(acoustic_latents.device) + + # Connect to language model space + acoustic_connected = self.model.acoustic_connector(acoustic_features)[speech_masks.cpu()] + + return acoustic_features, acoustic_connected + else: + raise NotImplementedError(f"Speech type {speech_type} not implemented") + + # @can_return_tuple + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speech_input_mask: Optional[torch.BoolTensor] = None, + logits_to_keep: Union[int, slice] = 0, + **kwargs, + ) -> Union[Tuple, VibeVoiceCausalLMOutputWithPast]: + """ + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + speech_tensors (`torch.FloatTensor`, *optional*): + Input speech waveforms for voice cloning or speech understanding. + speech_masks (`torch.BoolTensor`, *optional*): + Masks indicating valid speech frames. + speech_input_mask (`torch.BoolTensor`, *optional*): + Positions in the input sequence where speech embeddings should be inserted. + + Returns: + `VibeVoiceCausalLMOutputWithPast` or tuple + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Get embeddings + if inputs_embeds is None: + inputs_embeds = self.model.get_input_embeddings()(input_ids) + + # Process speech inputs if provided + if speech_tensors is not None and speech_masks is not None: + acoustic_features, speech_embeds = self._process_speech_inputs(speech_tensors.to(self.dtype), speech_masks) + if speech_input_mask is not None: + inputs_embeds[speech_input_mask] = speech_embeds + + outputs = self.model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + if labels is not None: + raise NotImplementedError("Loss computation is not implemented in this version.") + + return VibeVoiceCausalLMOutputWithPast( + logits=logits, + past_key_values=outputs.past_key_values, + last_hidden_state=hidden_states, + attentions=outputs.attentions, + ) + + def _build_generate_config_model_kwargs(self, generation_config, inputs, tokenizer, return_processors=False, **kwargs): + if generation_config is None: + generation_config = GenerationConfig( + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id = tokenizer.pad_token_id + ) + else: + generation_config = GenerationConfig( + **generation_config, + bos_token_id=tokenizer.bos_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id = tokenizer.pad_token_id + ) + + generation_config, model_kwargs = self._prepare_generation_config( + generation_config, + True, + speech_start_id=tokenizer.speech_start_id, + speech_end_id=tokenizer.speech_end_id, + speech_diffusion_id=tokenizer.speech_diffusion_id, + **kwargs + ) + generation_config.speech_start_id = tokenizer.speech_start_id + generation_config.speech_end_id = tokenizer.speech_end_id + generation_config.speech_diffusion_id = tokenizer.speech_diffusion_id + + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs) + batch_size = inputs_tensor.shape[0] + device = self.device + + self._prepare_special_tokens(generation_config, True, device=device) + generation_config.use_cache = True + model_kwargs["use_cache"] = generation_config.use_cache + input_ids = inputs_tensor.to(self.device) + + input_ids_length = input_ids.shape[1] + has_default_max_length = 
kwargs.get("max_length") is None and generation_config.max_length is not None + has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None + generation_config = self._prepare_generated_length( + generation_config=generation_config, + has_default_max_length=has_default_max_length, + has_default_min_length=has_default_min_length, + model_input_name=model_input_name, + inputs_tensor=inputs_tensor, + input_ids_length=input_ids_length, + ) + + max_cache_length = generation_config.max_length - 1 + self._prepare_cache_for_generation(generation_config, model_kwargs, None, batch_size, max_cache_length, device) + model_kwargs['cache_position'] = torch.arange(input_ids_length, device=device, dtype=torch.long) + for k, v in model_kwargs.items(): + if isinstance(v, torch.Tensor): + model_kwargs[k] = v.to(device=device) + + if return_processors: + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + encoder_input_ids=inputs_tensor, + prefix_allowed_tokens_fn=None, + logits_processor=LogitsProcessorList(), + device=inputs_tensor.device, + model_kwargs=model_kwargs, + ) + + stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=StoppingCriteriaList()) + + return generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria + else: + return generation_config, model_kwargs, input_ids + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + audio_streamer: Optional[Union[AudioStreamer, AsyncAudioStreamer]] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + speech_tensors: Optional[torch.FloatTensor] = None, + speech_masks: Optional[torch.BoolTensor] = None, + speech_input_mask: Optional[torch.BoolTensor] = None, + return_speech: bool = True, + cfg_scale: float = 1.0, + stop_check_fn: Optional[Callable[[], bool]] = None, + **kwargs, + ) -> Union[torch.LongTensor, VibeVoiceGenerationOutput]: + """ + Generates sequences of token ids and optionally speech outputs. + + Args: + All standard generation arguments from GenerationMixin + negative_prompt_ids: Negative prompt for CFG in speech generation + negative_prompt_attention_mask: Attention mask for negative prompt + speech_tensors: Input speech for voice cloning + speech_masks: Masks for speech tensors + speech_input_mask: Positions to insert speech embeddings + return_speech: Whether to decode and return speech outputs + cfg_scale: CFG scale for speech generation + stop_check_fn: Optional callable that returns True if generation should stop + + Returns: + Generated token sequences and optionally speech outputs + """ + # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria + parsed_scripts = kwargs.pop("parsed_scripts", None) + all_speakers_list = kwargs.pop("all_speakers_list", None) + max_length_times = kwargs.pop("max_length_times", 2) + + if kwargs.get('max_new_tokens', None) is None: + kwargs['max_new_tokens'] = self.config.decoder_config.max_position_embeddings - kwargs['input_ids'].shape[-1] + + generation_config, model_kwargs, input_ids, logits_processor, stopping_criteria = self._build_generate_config_model_kwargs( + generation_config, inputs, tokenizer, return_processors=True, **kwargs + ) + + negative_kwargs = { + 'input_ids': torch.full((kwargs['input_ids'].shape[0], 1), tokenizer.speech_start_id, dtype=torch.long, device=kwargs['input_ids'].device), + 'attention_mask': torch.ones((kwargs['input_ids'].shape[0], 1), dtype=torch.long, device=kwargs['input_ids'].device), + 'max_new_tokens': kwargs.get('max_new_tokens', 100) + } + negative_generation_config, negative_model_kwargs, negative_input_ids = self._build_generate_config_model_kwargs( + None, None, tokenizer, return_processors=False, **negative_kwargs + ) + + acoustic_cache = VibeVoiceTokenizerStreamingCache() + semantic_cache = VibeVoiceTokenizerStreamingCache() + + batch_size = input_ids.shape[0] + device = input_ids.device + finished_tags = torch.zeros(batch_size, dtype=torch.bool, device=device) + correct_cnt = torch.zeros(batch_size, dtype=torch.long, device=device) + is_prefill = True + inputs_embeds = None + verbose = kwargs.get("verbose", False) + + # Initialize audio chunks storage for each sample + audio_chunks = [[] for _ in range(batch_size)] + + initial_length = input_ids.shape[-1] + initial_length_per_sample = model_kwargs['attention_mask'].sum(dim=-1) + + # Define all valid tokens that can be generated + valid_tokens = [ + generation_config.speech_start_id, + generation_config.speech_end_id, + generation_config.speech_diffusion_id, + generation_config.eos_token_id + ] + # Add bos_token_id if it exists + if hasattr(generation_config, 'bos_token_id') and generation_config.bos_token_id is not None: + valid_tokens.append(generation_config.bos_token_id) + + # Add custom processor to constrain token generation + token_constraint_processor = VibeVoiceTokenConstraintProcessor(valid_tokens, device=device) + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(token_constraint_processor) + + max_steps = min(generation_config.max_length - initial_length, int(max_length_times * initial_length)) + max_step_per_sample = torch.min(generation_config.max_length - initial_length_per_sample, (max_length_times * initial_length_per_sample).long()) + reach_max_step_sample = torch.zeros(batch_size, dtype=torch.bool, device=device) + + # Create progress iterator if verbose + if kwargs.get("show_progress_bar", True): + progress_bar = tqdm(range(max_steps), desc="Generating", leave=False) + else: + progress_bar = range(max_steps) + + for step in progress_bar: + # Check for external stop signal + if stop_check_fn is not None and stop_check_fn(): + if verbose: + print(f"Generation stopped externally at step {step + 1}") + # End the audio streamer if it exists + if audio_streamer is not None: + audio_streamer.end() + break + + # Check if audio_streamer has been ended (stopped externally) + if audio_streamer is not None and hasattr(audio_streamer, 'finished_flags'): 
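+                # A set finished flag indicates the streamer was stopped on the consumer side, so generation is aborted.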
+ if any(audio_streamer.finished_flags): + if verbose: + print(f"Audio generation stopped externally at step {step + 1}") + break + + if finished_tags.all(): + if hasattr(progress_bar, 'set_description'): + progress_bar.set_description("Generation complete") + break + + if input_ids.shape[-1] >= generation_config.max_length: + print(f"Reached maximum generation length {generation_config.max_length}, stopped it.") + reached_samples = torch.arange(batch_size, device=device)[~finished_tags] + if reached_samples.numel() > 0: + reach_max_step_sample[reached_samples] = True + break + + # Update progress bar description with active samples + if hasattr(progress_bar, 'set_description'): + active_samples = (~finished_tags).sum().item() + progress_bar.set_description(f"Generating (active: {active_samples}/{batch_size})") + + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + if is_prefill: + # we process the speech inputs only during the first generation step + prefill_inputs = { + "speech_tensors": speech_tensors.to(device=device), + "speech_masks": speech_masks.to(device), + "speech_input_mask": speech_input_mask.to(device), + } + is_prefill = False + else: + _ = model_inputs.pop('inputs_embeds', None) + prefill_inputs = {'inputs_embeds': inputs_embeds} + + # Forward pass through the model + outputs = self( + **model_inputs, **prefill_inputs, logits_to_keep=1, return_dict=True, output_attentions=False, output_hidden_states=False, + ) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False, + ) + + # Get logits and apply logits processor + next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) + # next_token_logits = outputs.logits[:, -1, :].to(copy=True, device=input_ids.device) + next_token_scores = logits_processor(input_ids, next_token_logits) + + # token selection + if generation_config.do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + next_tokens[finished_tags] = generation_config.eos_token_id + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if not kwargs.get('refresh_negative', True): + negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs) + # Forward negative pass through the model + if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None: + negative_model_inputs['inputs_embeds'] = inputs_embeds + negative_model_inputs['input_ids'] = None + + negative_outputs = self( + **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False, + ) + negative_model_kwargs = self._update_model_kwargs_for_generation( + negative_outputs, negative_model_kwargs, is_encoder_decoder=False, + ) + negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1) + + # reached end of generation + if (next_tokens == generation_config.eos_token_id).any(): + eos_indices = (next_tokens == generation_config.eos_token_id).nonzero(as_tuple=False).squeeze(1) + # Only print for samples that are newly finished (not already marked as finished) + new_eos_indices = eos_indices[~finished_tags[eos_indices]] + if new_eos_indices.numel() > 0: + finished_tags[new_eos_indices] = True + if 
verbose: + print(f"Samples {new_eos_indices.tolist()} reached EOS token at step {step + 1}.", flush=True) + if audio_streamer is not None: + audio_streamer.end(new_eos_indices) + + # Check if any sample reached its maximum generation length + max_length_reached = step >= max_step_per_sample + new_max_length_indices = torch.nonzero(max_length_reached & ~finished_tags, as_tuple=False).squeeze(1) + if new_max_length_indices.numel() > 0: + finished_tags[new_max_length_indices] = True + reach_max_step_sample[new_max_length_indices] = True + if verbose: + print(f"Samples {new_max_length_indices.tolist()} reached max generation length at step {step + 1}.", flush=True) + if audio_streamer is not None: + audio_streamer.end(new_max_length_indices) + + # speech_end + diffusion_end_indices = (next_tokens == generation_config.speech_end_id).nonzero(as_tuple=False).squeeze(1) + if diffusion_end_indices.numel() > 0: + # Clear tokenizer caches for samples that reached speech end + acoustic_cache.set_to_zero(diffusion_end_indices) + semantic_cache.set_to_zero(diffusion_end_indices) + + # speech_begin + diffusion_start_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_start_id)] + if diffusion_start_indices.numel() > 0 and kwargs.get('refresh_negative', True): + # update attention mask + for i, sample_idx in enumerate(diffusion_start_indices.tolist()): + negative_model_kwargs['attention_mask'][sample_idx, :] = 0 + negative_model_kwargs['attention_mask'][sample_idx, -1] = 1 + # update past key values + for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache, + negative_model_kwargs['past_key_values'].value_cache)): + # Process each non-diffusion sample + for sample_idx in diffusion_start_indices.tolist(): + # Shift cache for this sample + k_cache[sample_idx, :, -1, :] = k_cache[sample_idx, :, 0, :].clone() + v_cache[sample_idx, :, -1, :] = v_cache[sample_idx, :, 0, :].clone() + # update negative_input_ids + for sample_idx in diffusion_start_indices.tolist(): + negative_input_ids[sample_idx, -1] = generation_config.speech_start_id + + # Prepare inputs_embeds for next iteration + # Initialize with default embeddings for all tokens + next_inputs_embeds = self.model.get_input_embeddings()(next_tokens).unsqueeze(1) # [batch_size, 1, hidden_size] + + # forward diffusion + # Diffusion indices are those that are not finished and not special tokens + diffusion_indices = torch.arange(batch_size, device=device)[~finished_tags & (next_tokens == generation_config.speech_diffusion_id)] + + if diffusion_indices.numel() > 0: + if kwargs.get('refresh_negative', True): + negative_model_inputs = self.prepare_inputs_for_generation(negative_input_ids, **negative_model_kwargs) + # Forward negative pass through the model + if negative_model_inputs['inputs_embeds'] is None and inputs_embeds is not None: + negative_model_inputs['inputs_embeds'] = inputs_embeds + negative_model_inputs['input_ids'] = None + + negative_outputs = self( + **negative_model_inputs, logits_to_keep=0, return_dict=True, output_attentions=False, output_hidden_states=False, + ) + negative_model_kwargs = self._update_model_kwargs_for_generation( + negative_outputs, negative_model_kwargs, is_encoder_decoder=False, + ) + negative_input_ids = torch.cat([negative_input_ids, next_tokens[:, None]], dim=-1) + # correct the non-diffusion indices + # we forward all samples' negative outputs even if + # they are not in diffusion mode to keep the cache consistent + # So 
we need to correct the kv cache of non-diffusion samples + non_diffusion_mask = ~finished_tags & (next_tokens != generation_config.speech_diffusion_id) + if non_diffusion_mask.any(): + non_diffusion_indices = torch.arange(batch_size, device=device)[non_diffusion_mask] + start_indices = correct_cnt[non_diffusion_indices] + + # 1. Update attention_mask - need to handle each sample separately + seq_len = negative_model_kwargs['attention_mask'].shape[1] + for i, (sample_idx, start_idx) in enumerate(zip(non_diffusion_indices.tolist(), start_indices.tolist())): + # Shift the attention mask for this sample + if start_idx + 1 < seq_len - 1: + negative_model_kwargs['attention_mask'][sample_idx, start_idx+1:] = \ + negative_model_kwargs['attention_mask'][sample_idx, start_idx:-1].clone() + negative_model_kwargs['attention_mask'][sample_idx, start_idx] = 0 + + # 2. Update past_key_values + for layer_idx, (k_cache, v_cache) in enumerate(zip(negative_model_kwargs['past_key_values'].key_cache, + negative_model_kwargs['past_key_values'].value_cache)): + # Process each non-diffusion sample + for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()): + if start_idx + 1 < k_cache.shape[2] - 1: + # Shift cache for this sample + k_cache[sample_idx, :, start_idx+1:, :] = k_cache[sample_idx, :, start_idx:-1, :].clone() + v_cache[sample_idx, :, start_idx+1:, :] = v_cache[sample_idx, :, start_idx:-1, :].clone() + + # 3. Update negative_input_ids + for sample_idx, start_idx in zip(non_diffusion_indices.tolist(), start_indices.tolist()): + if start_idx + 1 < negative_input_ids.shape[1] - 1: + negative_input_ids[sample_idx, start_idx+1:] = \ + negative_input_ids[sample_idx, start_idx:-1].clone() + + correct_cnt[non_diffusion_indices] += 1 + + positive_condition = outputs.last_hidden_state[diffusion_indices, -1, :] + negative_condition = negative_outputs.last_hidden_state[diffusion_indices, -1, :] + + speech_latent = self.sample_speech_tokens( + positive_condition, + negative_condition, + cfg_scale=cfg_scale, + ).unsqueeze(1) + + # Decode acoustic latent to audio using acoustic streaming cache + scaled_latent = speech_latent / self.model.speech_scaling_factor.to(speech_latent.device) - self.model.speech_bias_factor.to(speech_latent.device) + audio_chunk = self.model.acoustic_tokenizer.decode( + scaled_latent.to(self.model.acoustic_tokenizer.device), + cache=acoustic_cache, # Use acoustic-specific cache + sample_indices=diffusion_indices.to(self.model.acoustic_tokenizer.device), + use_cache=True, + debug=False + ) + + # Store audio chunks for each sample + for i, sample_idx in enumerate(diffusion_indices): + idx = sample_idx.item() + # Only append audio chunk if the sample is not finished + if not finished_tags[idx]: + audio_chunks[idx].append(audio_chunk[i]) + + # Add streaming support here + if audio_streamer is not None: + # Stream the audio chunks immediately + audio_streamer.put(audio_chunk, diffusion_indices) + + # Encode audio to semantic features using semantic streaming cache + semantic_features = self.model.semantic_tokenizer.encode( + audio_chunk, + cache=semantic_cache, # Use semantic-specific cache + sample_indices=diffusion_indices, + use_cache=True, + debug=False + ).mean # semantic tokenizer has no VAE. 
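+                # Recap of the per-step speech path for samples in diffusion mode
+                # (descriptive comment only; exact tensor shapes are implementation details):
+                #   speech_latent     <- sample_speech_tokens(pos_cond, neg_cond, cfg_scale)  # CFG-guided diffusion head
+                #   audio_chunk       <- acoustic_tokenizer.decode(latent / scale - bias)     # streaming decode to waveform
+                #   semantic_features <- semantic_tokenizer.encode(audio_chunk).mean          # re-encode for conditioning
+                # Both streams are fused below into the inputs_embeds used at the next step.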
+ + # Combine acoustic and semantic features for next input + acoustic_embed = self.model.acoustic_connector(speech_latent) + semantic_embed = self.model.semantic_connector(semantic_features) + diffusion_embeds = acoustic_embed + semantic_embed + + # Update embeddings for diffusion indices + next_inputs_embeds[diffusion_indices] = diffusion_embeds + + # Set inputs_embeds for next iteration + inputs_embeds = next_inputs_embeds + + if audio_streamer is not None: + audio_streamer.end() + + # Concatenate audio chunks for each sample + final_audio_outputs = [] + for sample_chunks in audio_chunks: + if sample_chunks: + # Concatenate all chunks along the time dimension (assumed to be the last dimension) + concatenated_audio = torch.cat(sample_chunks, dim=-1) + final_audio_outputs.append(concatenated_audio) + else: + # If no audio was generated for this sample, append None + final_audio_outputs.append(None) + + return VibeVoiceGenerationOutput( + sequences=input_ids, + speech_outputs=final_audio_outputs if return_speech else None, + reach_max_step_sample=reach_max_step_sample, + ) + + @torch.no_grad() + def sample_speech_tokens(self, condition, neg_condition, cfg_scale=3.0): + self.model.noise_scheduler.set_timesteps(self.ddpm_inference_steps) + condition = torch.cat([condition, neg_condition], dim=0).to(self.model.prediction_head.device) + speech = torch.randn(condition.shape[0], self.config.acoustic_vae_dim).to(condition) + for t in self.model.noise_scheduler.timesteps: + half = speech[: len(speech) // 2] + combined = torch.cat([half, half], dim=0) + eps = self.model.prediction_head(combined, t.repeat(combined.shape[0]).to(combined), condition=condition) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + speech = self.model.noise_scheduler.step(eps, t, speech).prev_sample + return speech[: len(speech) // 2] + + +AutoModelForCausalLM.register(VibeVoiceConfig, VibeVoiceForConditionalGenerationInference) + +__all__ = [ + "VibeVoiceForConditionalGenerationInference", +] diff --git a/vibevoice/modular/modular_vibevoice_diffusion_head.py b/vibevoice/modular/modular_vibevoice_diffusion_head.py new file mode 100644 index 0000000..59de50f --- /dev/null +++ b/vibevoice/modular/modular_vibevoice_diffusion_head.py @@ -0,0 +1,287 @@ +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.auto import AutoModel +from transformers.modeling_utils import PreTrainedModel +# from transformers.modeling_layers import GradientCheckpointingLayer +from transformers.activations import ACT2FN +from transformers.utils import logging + +from .configuration_vibevoice import VibeVoiceDiffusionHeadConfig + + +logger = logging.get_logger(__name__) + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, memory_efficient=False): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter('weight', None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + def extra_repr(self) -> str: + return f'dim={self.dim}, 
eps={self.eps}, elementwise_affine={self.elementwise_affine}' + +def modulate(x, shift, scale): + """Apply modulation to input tensor.""" + return x * (1 + scale) + shift + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + + Args: + hidden_size (`int`): Size of the output embedding + frequency_embedding_size (`int`, optional): Size of the intermediate frequency embedding + """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=False), + # nn.SiLU(), + ACT2FN['silu'], + nn.Linear(hidden_size, hidden_size, bias=False), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + Args: + t (`torch.Tensor`): A 1-D Tensor of N indices, one per batch element. + These may be fractional. + dim (`int`): The dimension of the output. + max_period (`int`, optional): Controls the minimum frequency of the embeddings. + + Returns: + `torch.Tensor`: An [N, D] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding.to(t.dtype) + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class FeedForwardNetwork(nn.Module): + """ + Standard feed-forward network with SwiGLU activation. + + Args: + embed_dim (`int`): Input dimension + ffn_dim (`int`): Hidden dimension + """ + def __init__( + self, + embed_dim, + ffn_dim, + ): + super().__init__() + self.embed_dim = embed_dim + self.gate_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.up_proj = nn.Linear(self.embed_dim, ffn_dim, bias=False) + self.down_proj = nn.Linear(ffn_dim, self.embed_dim, bias=False) + self.act_fn = ACT2FN['silu'] # Using SiLU as the activation function + + def forward(self, x): + gate = self.gate_proj(x) + up = self.up_proj(x) + + # SwiGLU activation + # gate = F.silu(gate) + gate = self.act_fn(gate) + return self.down_proj(gate * up) + + +class HeadLayer(nn.Module): + """ + A layer in the diffusion head. + + Args: + embed_dim (`int`): Input dimension + ffn_dim (`int`): Hidden dimension + cond_dim (`int`): Condition embedding dimension + norm_eps (`float`, optional): Epsilon for normalization + """ + def __init__( + self, + embed_dim, + ffn_dim, + cond_dim, + norm_eps=1e-5, + ): + super().__init__() + self.embed_dim = embed_dim + self.cond_dim = cond_dim + self.ffn_dim = ffn_dim + self.ffn = FeedForwardNetwork( + self.embed_dim, + self.ffn_dim, + ) + self.norm = RMSNorm(self.embed_dim, eps=norm_eps) + self.adaLN_modulation = nn.Sequential( + # nn.SiLU(), + ACT2FN['silu'], + nn.Linear(cond_dim, 3 * self.embed_dim, bias=False) + ) + + def forward(self, x, c): + shift_ffn, scale_ffn, gate_ffn = self.adaLN_modulation(c).chunk(3, dim=-1) + x = x + gate_ffn * self.ffn(modulate(self.norm(x), shift_ffn, scale_ffn)) + return x + + +class FinalLayer(nn.Module): + """ + Final layer in the diffusion head. 
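+    Applies an adaLN-modulated RMSNorm (shift and scale derived from the condition)
+    followed by a linear projection from the hidden size back to the latent size.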
+ + Args: + hidden_size (`int`): Input dimension + output_size (`int`): Output dimension + cond_size (`int`): Condition embedding dimension + norm_eps (`float`, optional): Epsilon for normalization + """ + def __init__(self, hidden_size, output_size, cond_size, norm_eps=1e-5): + super().__init__() + self.norm_final = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=False) + self.linear = nn.Linear(hidden_size, output_size, bias=False) + self.adaLN_modulation = nn.Sequential( + # nn.SiLU(), + ACT2FN['silu'], + nn.Linear(cond_size, 2 * hidden_size, bias=False) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class VibeVoiceDiffusionHead(PreTrainedModel): + """ + Diffusion head model for vibevoice. + + Args: + config (`VibeVoiceDiffusionHeadConfig`): Model configuration + latent_size (`int`, optional): Size of the latent space. If not provided, uses `config.latent_size`. + """ + config_class = VibeVoiceDiffusionHeadConfig + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def __init__( + self, + config, + ): + super().__init__(config) + self.config = config + self.cond_dim = config.hidden_size + latent_size = config.latent_size + + self.noisy_images_proj = nn.Linear(latent_size, config.hidden_size, bias=False) + self.cond_proj = nn.Linear(config.hidden_size, self.cond_dim, bias=False) + self.t_embedder = TimestepEmbedder(self.cond_dim) + + ffn_dim = int(config.hidden_size * config.head_ffn_ratio) + + # Create the intermediate layers + self.layers = nn.ModuleList([ + HeadLayer( + embed_dim=config.hidden_size, + ffn_dim=ffn_dim, + cond_dim=self.cond_dim, + norm_eps=config.rms_norm_eps + ) + for _ in range(config.head_layers) + ]) + + # Final layer for output + self.final_layer = FinalLayer( + hidden_size=config.hidden_size, + output_size=latent_size, + cond_size=self.cond_dim, + norm_eps=config.rms_norm_eps + ) + + self.initialize_weights() + + def initialize_weights(self): + """Initialize the weights of the model.""" + # Initialize timestep embedder + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers + for layer in self.layers: + nn.init.constant_(layer.adaLN_modulation[-1].weight, 0) + + # Zero-out output layers + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + + def forward( + self, + noisy_images, + timesteps, + condition, + ): + """ + Forward pass of the prediction head. 
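+        The timestep embedding and the projected condition are summed into a single
+        modulation vector that drives every HeadLayer and the final output projection.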
+ + Args: + noisy_images (`torch.Tensor`): Noisy images/latents to denoise + timesteps (`torch.Tensor`): Timesteps for diffusion + condition (`torch.Tensor`): Conditioning information + + Returns: + `torch.Tensor`: The predicted noise/velocity + """ + x = self.noisy_images_proj(noisy_images) + t = self.t_embedder(timesteps) + condition = self.cond_proj(condition) + c = condition + t + + for layer in self.layers: + x = layer(x, c) + + x = self.final_layer(x, c) + return x + + +AutoModel.register(VibeVoiceDiffusionHeadConfig, VibeVoiceDiffusionHead) + +__all__ = [ + "VibeVoiceDiffusionHead", +] \ No newline at end of file diff --git a/vibevoice/modular/modular_vibevoice_text_tokenizer.py b/vibevoice/modular/modular_vibevoice_text_tokenizer.py new file mode 100644 index 0000000..bfa7bdd --- /dev/null +++ b/vibevoice/modular/modular_vibevoice_text_tokenizer.py @@ -0,0 +1,214 @@ +"""Tokenization classes for vibevoice.""" + +from typing import List, Optional, Union + +from transformers.utils import logging +from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer +from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast + +logger = logging.get_logger(__name__) + + +class VibeVoiceTextTokenizer(Qwen2Tokenizer): + """ + Construct a VibeVoice tokenizer. Based on the Qwen2 tokenizer with additional special tokens for speech. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + merges_file (`str`): + Path to the merges file. + errors (`str`, *optional*, defaults to `"replace"`): + Paradigm to follow when decoding bytes to UTF-8. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. + bos_token (`str`, *optional*): + The beginning of sequence token. Not used for vibevoice. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding. + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to add special tokens when encoding. 
+ """ + + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + merges_file, + errors="replace", + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + add_prefix_space=False, + add_special_tokens=True, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + errors=errors, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + add_special_tokens=add_special_tokens, + **kwargs, + ) + + # Add VibeVoice-specific special tokens + self._add_vibevoice_special_tokens() + + def _add_vibevoice_special_tokens(self): + """Add VibeVoice-specific special tokens.""" + special_tokens = { + "additional_special_tokens": [ + "<|vision_start|>", # Speech start (reusing vision tokens) + "<|vision_end|>", # Speech end + "<|vision_pad|>", # Speech diffusion pad + ] + } + num_added = self.add_special_tokens(special_tokens) + + # Cache special token IDs + self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") + self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") + self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") + + self._eos_id = self.convert_tokens_to_ids('<|endoftext|>') + + return num_added + + @property + def eos_id(self) -> int: + """Id of the end of sequence token.""" + return self._eos_id + + @property + def speech_start_id(self) -> int: + """Id of the speech start token.""" + return self._speech_start_id + + @property + def speech_end_id(self) -> int: + """Id of the speech end token.""" + return self._speech_end_id + + @property + def speech_diffusion_id(self) -> int: + """Id of the speech diffusion token.""" + return self._speech_diffusion_id + + @property + def pad_id(self) -> int: + """Id used for padding (returns -100 for loss masking).""" + return -100 + + +class VibeVoiceTextTokenizerFast(Qwen2TokenizerFast): + """ + Construct a "fast" VibeVoice tokenizer (backed by HuggingFace's *tokenizers* library). + Based on the Qwen2 tokenizer with additional special tokens for speech. + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + merges_file (`str`, *optional*): + Path to the merges file. + tokenizer_file (`str`, *optional*): + Path to [tokenizers](https://github.com/huggingface/tokenizers) file. + unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The unknown token. + bos_token (`str`, *optional*): + The beginning of sequence token. Not used for vibevoice. + eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The end of sequence token. + pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`): + The token used for padding. 
+ """ + + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + merges_file=None, + tokenizer_file=None, + unk_token="<|endoftext|>", + bos_token=None, + eos_token="<|endoftext|>", + pad_token="<|endoftext|>", + add_prefix_space=False, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + merges_file=merges_file, + tokenizer_file=tokenizer_file, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + pad_token=pad_token, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + # Add VibeVoice-specific special tokens + self._add_vibevoice_special_tokens() + + def _add_vibevoice_special_tokens(self): + """Add VibeVoice-specific special tokens.""" + special_tokens = { + "additional_special_tokens": [ + "<|vision_start|>", # Speech start (reusing vision tokens) + "<|vision_end|>", # Speech end + "<|vision_pad|>", # Speech diffusion pad + ] + } + num_added = self.add_special_tokens(special_tokens) + + # Cache special token IDs + self._speech_start_id = self.convert_tokens_to_ids("<|vision_start|>") + self._speech_end_id = self.convert_tokens_to_ids("<|vision_end|>") + self._speech_diffusion_id = self.convert_tokens_to_ids("<|vision_pad|>") + + # self._eos_id = self.convert_tokens_to_ids('<|endoftext|>') + self._eos_id = self.eos_token_id # qwen2 / qwen3 + self._pad_id = self.convert_tokens_to_ids('<|image_pad|>') + + return num_added + + @property + def eos_id(self) -> int: + """Id of the end of sequence token.""" + return self._eos_id + + @property + def speech_start_id(self) -> int: + """Id of the speech start token.""" + return self._speech_start_id + + @property + def speech_end_id(self) -> int: + """Id of the speech end token.""" + return self._speech_end_id + + @property + def speech_diffusion_id(self) -> int: + """Id of the speech diffusion token.""" + return self._speech_diffusion_id + + @property + def pad_id(self) -> int: + """Id used for padding (returns -100 for loss masking).""" + return self._pad_id + + +__all__ = [ + "VibeVoiceTextTokenizer", + "VibeVoiceTextTokenizerFast", +] \ No newline at end of file diff --git a/vibevoice/modular/modular_vibevoice_tokenizer.py b/vibevoice/modular/modular_vibevoice_tokenizer.py new file mode 100644 index 0000000..fbd5182 --- /dev/null +++ b/vibevoice/modular/modular_vibevoice_tokenizer.py @@ -0,0 +1,1195 @@ +import math +import typing as tp +from functools import partial +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple, Union +import copy + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.auto import AutoModel + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from transformers.modeling_utils import PreTrainedModel +from transformers.activations import ACT2FN + +from .configuration_vibevoice import VibeVoiceAcousticTokenizerConfig, VibeVoiceSemanticTokenizerConfig + +logger = logging.get_logger(__name__) + +import os +# Try to import APEX FusedRMSNorm +try: + from apex.normalization.fused_layer_norm import fused_rms_norm_affine + APEX_AVAILABLE = True + logger.info("APEX FusedRMSNorm is available and will be used for optimization") + if int(os.getenv("OPTIMIZE_FOR_SPEED", "0")) == 0: + APEX_AVAILABLE = False + logger.warning("APEX FusedRMSNorm is disabled by environment variable OPTIMIZE_FOR_SPEED=0") +except ImportError: + APEX_AVAILABLE = False + logger.warning("APEX FusedRMSNorm not available, using native 
implementation") +# APEX_AVAILABLE=False + +# Normalization modules +class ConvLayerNorm(nn.LayerNorm): + """ + Convolution-friendly LayerNorm that moves channels to last dimensions + before running the normalization and moves them back to original position right after. + """ + def __init__(self, normalized_shape: tp.Union[int, tp.List[int], torch.Size], **kwargs): + super().__init__(normalized_shape, **kwargs) + + def forward(self, x): + x = x.transpose(1, 2) # b ... t -> b t ... + x = nn.functional.layer_norm(x.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(x) + x = x.transpose(1, 2) # b t ... -> b ... t + return x + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None): + super().__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + weight_shape = (dim,) if weight_shape is None else weight_shape + self.weight = nn.Parameter(torch.ones(weight_shape)) + else: + self.register_parameter('weight', None) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + return output + + def extra_repr(self) -> str: + return f'dim={self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}' + +class ConvRMSNorm(RMSNorm): + def __init__(self, dim: int, eps: float = 1e-5, elementwise_affine=True, weight_shape=None): + super().__init__(dim, eps, elementwise_affine, weight_shape) + + def forward(self, x): + x = x.transpose(1, 2) # b ... t -> b t ... + if (not APEX_AVAILABLE) or (not self.elementwise_affine): + # Fallback to native implementation + output = self._norm(x.float()).type_as(x) + if self.weight is not None: + output = output * self.weight + else: + output = fused_rms_norm_affine(x, self.weight, self.weight.shape, self.eps) + output = output.transpose(1, 2) # b t ... -> b ... t + return output + +# Convolutional layers and utilities +CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm', + 'time_layer_norm', 'layer_norm', 'time_group_norm']) + + +def apply_parametrization_norm(module: nn.Module, norm: str = 'none') -> nn.Module: + assert norm in CONV_NORMALIZATIONS + if norm == 'weight_norm': + return nn.utils.weight_norm(module) + elif norm == 'spectral_norm': + return nn.utils.spectral_norm(module) + else: + # We already check was in CONV_NORMALIZATION, so any other choice + # doesn't need reparametrization. + return module + + +def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs) -> nn.Module: + """Return the proper normalization module. If causal is True, this will ensure the returned + module is causal, or return an error if the normalization doesn't support causal evaluation. 
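+    Note: 'weight_norm' and 'spectral_norm' are applied as parametrizations in
+    `apply_parametrization_norm`, so this function returns `nn.Identity()` for them
+    (and for 'none').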
+ """ + assert norm in CONV_NORMALIZATIONS + if norm == 'layer_norm': + assert isinstance(module, nn.modules.conv._ConvNd) + return ConvLayerNorm(module.out_channels, **norm_kwargs) + elif norm == 'time_group_norm': + if causal: + raise ValueError("GroupNorm doesn't support causal evaluation.") + assert isinstance(module, nn.modules.conv._ConvNd) + return nn.GroupNorm(1, module.out_channels, **norm_kwargs) + else: + return nn.Identity() + + +def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, + padding_total: int = 0) -> int: + """Calculate extra padding needed for convolution to have the same output length""" + length = x.shape[-1] + n_frames = (length - kernel_size + padding_total) / stride + 1 + ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total) + return ideal_length - length + + +def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'zero', value: float = 0.): + """Pad 1D input with handling for small inputs in reflect mode""" + length = x.shape[-1] + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + if mode == 'reflect': + max_pad = max(padding_left, padding_right) + extra_pad = 0 + if length <= max_pad: + extra_pad = max_pad - length + 1 + x = F.pad(x, (0, extra_pad)) + padded = F.pad(x, paddings, mode, value) + end = padded.shape[-1] - extra_pad + return padded[..., :end] + else: + return F.pad(x, paddings, mode, value) + + +def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]): + """Remove padding from x, handling properly zero padding. Only for 1d!""" + padding_left, padding_right = paddings + assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right) + assert (padding_left + padding_right) <= x.shape[-1] + end = x.shape[-1] - padding_right + return x[..., padding_left: end] + + +class NormConv1d(nn.Module): + """Wrapper around Conv1d and normalization applied to this conv""" + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.conv(x) + x = self.norm(x) + return x + + +class NormConvTranspose1d(nn.Module): + """Wrapper around ConvTranspose1d and normalization applied to this conv""" + def __init__(self, *args, causal: bool = False, norm: str = 'none', + norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs): + super().__init__() + self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm) + self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs) + self.norm_type = norm + + def forward(self, x): + x = self.convtr(x) + x = self.norm(x) + return x + + +class VibeVoiceTokenizerStreamingCache: + """Cache for streaming convolution, similar to KV cache in attention""" + def __init__(self): + self.cache = {} # Dict mapping (layer_id, sample_idx) to state tensor + + def get(self, layer_id: str, sample_indices: torch.Tensor) -> Optional[torch.Tensor]: + """Get cached states for given layer and sample indices""" + states = [] + max_length = 0 + + # First pass: collect states and find max length + for idx in sample_indices.tolist(): + key = (layer_id, idx) + if key not in self.cache: + return None # If any sample is missing, return None + state = self.cache[key] + states.append(state) + max_length 
= max(max_length, state.shape[-1]) + + # Second pass: pad states to max length if needed + if len(states) > 0 and states[0].dim() >= 2: + padded_states = [] + for state in states: + if state.shape[-1] < max_length: + # Pad on the time dimension (last dimension) + pad_size = max_length - state.shape[-1] + # Pad with zeros on the LEFT to align the most recent samples + padded_state = F.pad(state, (pad_size, 0), mode='constant', value=0) + padded_states.append(padded_state) + else: + padded_states.append(state) + return torch.stack(padded_states, dim=0) + else: + return torch.stack(states, dim=0) + + def set(self, layer_id: str, sample_indices: torch.Tensor, states: torch.Tensor): + """Set cached states for given layer and sample indices""" + for i, idx in enumerate(sample_indices.tolist()): + key = (layer_id, idx) + self.cache[key] = states[i].detach() + + def set_to_zero(self, sample_indices: torch.Tensor): + """Set all cached states to zero for given sample indices""" + for key in list(self.cache.keys()): + layer_id, sample_idx = key + if sample_idx in sample_indices.tolist(): + # Create zero tensor with same shape and dtype as cached tensor + cached_tensor = self.cache[key] + self.cache[key] = torch.zeros_like(cached_tensor) + + def clear(self, layer_id: Optional[str] = None, sample_indices: Optional[torch.Tensor] = None): + """Clear cache for specific layer/samples or everything""" + if layer_id is None and sample_indices is None: + self.cache.clear() + elif layer_id is not None and sample_indices is None: + # Clear all samples for a specific layer + keys_to_remove = [k for k in self.cache.keys() if k[0] == layer_id] + for k in keys_to_remove: + del self.cache[k] + elif layer_id is not None and sample_indices is not None: + # Clear specific samples for a specific layer + for idx in sample_indices.tolist(): + key = (layer_id, idx) + self.cache.pop(key, None) + +class SConv1d(nn.Module): + """Conv1d with built-in handling of asymmetric or causal padding and normalization.""" + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, dilation: int = 1, + groups: int = 1, bias: bool = True, causal: bool = False, + norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, + pad_mode: str = 'reflect'): + super().__init__() + self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride, + dilation=dilation, groups=groups, bias=bias, causal=causal, + norm=norm, norm_kwargs=norm_kwargs) + self.causal = causal + self.pad_mode = pad_mode + + # Store configuration + self.kernel_size = kernel_size + self.dilation = dilation + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + + # For causal convolution, we need to maintain kernel_size - 1 samples as context + # need to check use which context_size is more suitable + # self.context_size = (kernel_size - 1) * dilation + self.context_size = (kernel_size - 1) * dilation - (stride - 1) + + # For non-streaming mode, calculate padding + self.padding_total = (kernel_size - 1) * dilation - (stride - 1) + + # Create a unique layer ID for cache management + self._layer_id = None + + @property + def layer_id(self): + if self._layer_id is None: + self._layer_id = f"sconv1d_{id(self)}" + return self._layer_id + + def forward(self, x: torch.Tensor, + cache: Optional[VibeVoiceTokenizerStreamingCache] = None, + sample_indices: Optional[torch.Tensor] = None, + use_cache: bool = False, + debug: bool = False) -> torch.Tensor: + """ + Forward pass with optional streaming support via 
cache. + + Args: + x: Input tensor [batch_size, channels, time] + cache: VibeVoiceTokenizerStreamingCache object for maintaining states + sample_indices: Indices identifying each sample for cache management + use_cache: Whether to use cached states for streaming + debug: Whether to print debug information + + Returns: + Output tensor + """ + B, C, T = x.shape + + # Non-streaming mode + if not use_cache or cache is None: + return self._forward_non_streaming(x, debug=debug) + + # Streaming mode + assert self.causal, "Streaming mode is only supported for causal convolutions" + assert sample_indices is not None, "sample_indices must be provided for streaming mode" + assert len(sample_indices) == B, "sample_indices must match batch size" + + return self._forward_streaming(x, cache, sample_indices, debug) + + def _forward_streaming(self, x: torch.Tensor, + cache: VibeVoiceTokenizerStreamingCache, + sample_indices: torch.Tensor, + debug: bool = False) -> torch.Tensor: + """Streaming forward pass with cache operations kept separate from compiled code""" + B, C, T = x.shape + + # Cache operations (not compiled) + cached_states = cache.get(self.layer_id, sample_indices) + + if cached_states is None: + # First chunk - initialize with zeros for context + if self.context_size > 0: + cached_states = torch.zeros(B, C, self.context_size, device=x.device, dtype=x.dtype) + if debug: + print(f"[DEBUG] Initialized cache with shape: {cached_states.shape}, context_size={self.context_size}") + else: + cached_states = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype) + if debug: + print(f"[DEBUG] No context needed (kernel_size=stride)") + + # Concatenate cached states with input + if cached_states.shape[2] > 0: + input_with_context = torch.cat([cached_states, x], dim=2) + else: + input_with_context = x + + if debug: + print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_states.shape}, Combined: {input_with_context.shape}") + + # Apply convolution directly - no extra padding in streaming mode + # The conv layer will handle its own padding internally + output = self.conv(input_with_context) + + if debug: + print(f"[DEBUG] Output shape: {output.shape}") + + # Update cache for next chunk + if self.context_size > 0: + # Calculate how many samples to keep + total_input_length = input_with_context.shape[2] + + # Keep the last context_size samples + if total_input_length >= self.context_size: + new_cache_start = total_input_length - self.context_size + new_cache = input_with_context[:, :, new_cache_start:] + else: + # If we have less than context_size samples, keep everything + new_cache = input_with_context + + if debug: + print(f"[DEBUG] New cache shape: {new_cache.shape}") + + cache.set(self.layer_id, sample_indices, new_cache) + + return output + + def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor: + """Standard forward pass without streaming""" + B, C, T = x.shape + kernel_size = self.kernel_size + stride = self.stride + dilation = self.dilation + padding_total = self.padding_total + + # Compute extra padding for stride alignment + extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total) + + if debug: + print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}, padding_total={padding_total}, extra_padding={extra_padding}") + + if self.causal: + # Left padding for causal + if self.pad_mode == 'constant': + x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode, value=0) + else: + x = pad1d(x, (padding_total, extra_padding), 
mode=self.pad_mode) + else: + # Symmetric padding for non-causal + padding_right = padding_total // 2 + padding_left = padding_total - padding_right + x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode) + + if debug: + print(f"[DEBUG NON-STREAMING] After padding: {x.shape}") + + output = self.conv(x) + + if debug: + print(f"[DEBUG NON-STREAMING] Output shape: {output.shape}") + + return output + + +class SConvTranspose1d(nn.Module): + """ConvTranspose1d with built-in handling of asymmetric or causal padding and normalization.""" + def __init__(self, in_channels: int, out_channels: int, + kernel_size: int, stride: int = 1, causal: bool = False, + norm: str = 'none', trim_right_ratio: float = 1., + norm_kwargs: tp.Dict[str, tp.Any] = {}, bias: bool = True): + super().__init__() + self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride, + causal=causal, norm=norm, norm_kwargs=norm_kwargs, bias=bias) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert self.causal or self.trim_right_ratio == 1., \ + "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1. + + # Store configuration + self.kernel_size = kernel_size + self.stride = stride + self.in_channels = in_channels + self.out_channels = out_channels + + # For transposed convolution, padding calculation is different + self.padding_total = kernel_size - stride + + # For streaming, we need to keep track of input history + # Transposed conv needs to see multiple input samples to produce correct output + self.context_size = kernel_size - 1 + + # Create a unique layer ID for cache management + self._layer_id = None + + @property + def layer_id(self): + if self._layer_id is None: + self._layer_id = f"sconvtr1d_{id(self)}" + return self._layer_id + + def forward(self, x: torch.Tensor, + cache: Optional[VibeVoiceTokenizerStreamingCache] = None, + sample_indices: Optional[torch.Tensor] = None, + use_cache: bool = False, + debug: bool = False) -> torch.Tensor: + """ + Forward pass with optional streaming support via cache. 
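+
+        Args:
+            x: Input tensor [batch_size, channels, time]
+            cache: VibeVoiceTokenizerStreamingCache object for maintaining states
+            sample_indices: Indices identifying each sample for cache management
+            use_cache: Whether to use cached states for streaming
+            debug: Whether to print debug information
+
+        Returns:
+            Output tensor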
+ """ + B, C, T = x.shape + + # Non-streaming mode + if not use_cache or cache is None: + return self._forward_non_streaming(x, debug=debug) + + # Streaming mode + assert sample_indices is not None, "sample_indices must be provided for streaming mode" + assert len(sample_indices) == B, "sample_indices must match batch size" + + return self._forward_streaming(x, cache, sample_indices, debug) + + def _forward_streaming(self, x: torch.Tensor, + cache: VibeVoiceTokenizerStreamingCache, + sample_indices: torch.Tensor, + debug: bool = False) -> torch.Tensor: + """Streaming forward pass with cache operations kept separate from compiled code""" + B, C, T = x.shape + + # Cache operations (not compiled) + cached_input = cache.get(self.layer_id, sample_indices) + + if cached_input is None: + # First chunk - no history yet + cached_input = torch.zeros(B, C, 0, device=x.device, dtype=x.dtype) + if debug: + print(f"[DEBUG] Initialized empty cache for transposed conv") + + # Concatenate cached input with new input + full_input = torch.cat([cached_input, x], dim=2) + + if debug: + print(f"[DEBUG] Input shape: {x.shape}, Cache shape: {cached_input.shape}, Combined: {full_input.shape}") + + # First chunk or debug mode - use uncompiled version + full_output = self.convtr(full_input) + + if debug: + print(f"[DEBUG] Full transposed conv output shape: {full_output.shape}") + + # Calculate padding to remove + if self.causal: + padding_right = math.ceil(self.padding_total * self.trim_right_ratio) + padding_left = self.padding_total - padding_right + else: + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + + # Remove padding + if padding_left + padding_right > 0: + full_output = unpad1d(full_output, (padding_left, padding_right)) + + if debug: + print(f"[DEBUG] After unpadding: {full_output.shape}") + + # Determine which part of the output corresponds to the new input + if cached_input.shape[2] == 0: + # First chunk - return all output + output = full_output + else: + # Subsequent chunks - return only the new output + expected_new_output = T * self.stride + + # Take the last expected_new_output samples + if full_output.shape[2] >= expected_new_output: + output = full_output[:, :, -expected_new_output:] + else: + output = full_output + + if debug: + print(f"[DEBUG] Final streaming output shape: {output.shape}") + + # Update cache + if full_input.shape[2] > self.context_size: + new_cache = full_input[:, :, -self.context_size:] + else: + new_cache = full_input + + if debug: + print(f"[DEBUG] New cache shape: {new_cache.shape}") + + cache.set(self.layer_id, sample_indices, new_cache) + + return output + + def _forward_non_streaming(self, x: torch.Tensor, debug: bool = False) -> torch.Tensor: + """Standard forward pass without streaming""" + if debug: + print(f"[DEBUG NON-STREAMING] Input shape: {x.shape}") + + # Apply transposed convolution + y = self.convtr(x) + + if debug: + print(f"[DEBUG NON-STREAMING] After transposed conv: {y.shape}") + + # Calculate and remove padding + if self.causal: + padding_right = math.ceil(self.padding_total * self.trim_right_ratio) + padding_left = self.padding_total - padding_right + else: + padding_right = self.padding_total // 2 + padding_left = self.padding_total - padding_right + + if padding_left + padding_right > 0: + y = unpad1d(y, (padding_left, padding_right)) + + if debug: + print(f"[DEBUG NON-STREAMING] Final output shape: {y.shape}") + + return y + +# FFN +class FFN(nn.Module): + def __init__( + self, + embed_dim, + ffn_dim, + 
bias=False, + ): + super().__init__() + self.embed_dim = embed_dim + self.linear1 = nn.Linear(self.embed_dim, ffn_dim, bias=bias) + self.gelu = ACT2FN["gelu"] + self.linear2 = nn.Linear(ffn_dim, self.embed_dim, bias=bias) + + def forward(self, x): + x = self.linear1(x) + x = self.gelu(x) + x = self.linear2(x) + return x + + +class Convlayer(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + pad_mode='zeros', + norm='weight_norm', + causal=True, + ): + super().__init__() + self.conv = SConv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, + groups=groups, bias=bias, pad_mode=pad_mode, norm=norm, causal=causal) + + def forward(self, x): + return self.conv(x) + +class Block1D(nn.Module): + def __init__(self, dim, kernel_size=7, drop_path=0., mixer_layer='conv', + layer_scale_init_value=1e-6, **kwargs): + super().__init__() + + if kwargs.get('layernorm', 'LN') == 'LN': + self.norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6)) + self.ffn_norm = ConvLayerNorm(dim, eps=kwargs.get('eps', 1e-6)) + elif kwargs.get('layernorm', 'RMSNorm') == 'RMSNorm': + self.norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6)) + self.ffn_norm = ConvRMSNorm(dim, eps=kwargs.get('eps', 1e-6)) + + if mixer_layer == 'conv': + self.mixer = Convlayer(dim, dim, groups=kwargs.get('groups', 1), + kernel_size=kernel_size, + pad_mode=kwargs.get('pad_mode', 'reflect'), + norm=kwargs.get('norm', 'none'), + causal=kwargs.get('causal', True), + bias=kwargs.get('bias', True), + ) + elif mixer_layer == 'depthwise_conv': + self.mixer = Convlayer(dim, dim, groups=dim, + kernel_size=kernel_size, + pad_mode=kwargs.get('pad_mode', 'reflect'), + norm=kwargs.get('norm', 'none'), + causal=kwargs.get('causal', True), + bias=kwargs.get('bias', True), + ) + else: + raise ValueError(f"Unsupported mixer layer: {mixer_layer}") + + self.ffn = FFN( + dim, + kwargs.get('ffn_expansion', 4) * dim, + bias=kwargs.get('bias', False), + ) + self.drop_path = nn.Identity() if drop_path <= 0. else nn.modules.DropPath(drop_path) + + if layer_scale_init_value > 0: + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + self.ffn_gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + else: + self.gamma = None + self.ffn_gamma = None + + def forward(self, x): + # mixer + residual = x + x = self.norm(x) + x = self.mixer(x) + if self.gamma is not None: + x = x * self.gamma.unsqueeze(-1) + x = residual + self.drop_path(x) + + # ffn + residual = x + x = self.ffn_norm(x) + x = x.permute(0, 2, 1) + x = self.ffn(x) + x = x.permute(0, 2, 1) + if self.ffn_gamma is not None: + x = x * self.ffn_gamma.unsqueeze(-1) + x = residual + self.drop_path(x) + + return x + + +class TokenizerEncoder(nn.Module): + """ + Encoder component for the VibeVoice tokenizer that converts audio to latent representations. 
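+    The encoder is a stem convolution followed by alternating downsampling convolutions
+    (one per entry in `config.ratios`) and stacks of `Block1D` layers, ending in a final
+    normalization and a head convolution that projects to `config.dimension` channels.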
+ + Args: + config: Configuration object with model parameters + """ + def __init__(self, config): + super().__init__() + + # Extract parameters from config + self.channels = config.channels + self.dimension = config.dimension + self.n_filters = config.n_filters + self.ratios = list(reversed(config.ratios)) + self.depths = config.depths + self.n_residual_layers = getattr(config, "n_residual_layers", 1) + self.hop_length = np.prod(self.ratios) + self.causal = config.causal + + # Additional config parameters with defaults + kernel_size = getattr(config, "kernel_size", 7) + last_kernel_size = getattr(config, "last_kernel_size", 7) + norm = getattr(config, "norm", "none") + norm_params = getattr(config, "norm_params", {}) + pad_mode = getattr(config, "pad_mode", "reflect") + bias = getattr(config, "bias", True) + layernorm = getattr(config, "layernorm", "LN") + layernorm_eps = getattr(config, "layernorm_eps", 1e-6) + layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True) + drop_path_rate = getattr(config, "drop_path_rate", 0.0) + mixer_layer = getattr(config, "mixer_layer", "conv") + layer_scale_init_value = getattr(config, "layer_scale_init_value", 0) + disable_last_norm = getattr(config, "disable_last_norm", False) + + # determine the norm type based on layernorm + if layernorm == 'LN': + norm_type = ConvLayerNorm + elif layernorm == 'RMSNorm': + norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine) + else: + raise ValueError(f"Unsupported norm type: {layernorm}") + + # stem and intermediate downsampling conv layers + stem = nn.Sequential( + SConv1d(self.channels, self.n_filters, kernel_size, norm=norm, norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias), + ) + + self.downsample_layers = nn.ModuleList() + self.downsample_layers.append(stem) + for i in range(len(self.ratios)): + in_ch = self.n_filters * (2 ** i) + out_ch = self.n_filters * (2 ** (i + 1)) + downsample_layer = nn.Sequential( + SConv1d(in_ch, out_ch, kernel_size=self.ratios[i] * 2, stride=self.ratios[i], causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias) + ) + self.downsample_layers.append(downsample_layer) + + # configure the transformer blocks + layer_type = partial( + Block1D, + mixer_layer=mixer_layer, + layernorm=layernorm, + eps=layernorm_eps, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + layer_scale_init_value=layer_scale_init_value, + ) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + + for i in range(len(self.depths)): + in_ch = self.n_filters * (2 ** i) + stage = nn.Sequential( + *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])] + ) + self.stages.append(stage) + cur += self.depths[i] + + if not disable_last_norm: + self.norm = norm_type(in_ch, eps=layernorm_eps) + else: + self.norm = nn.Identity() + self.head = SConv1d(in_ch, self.dimension, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias) + + def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): + for i in range(len(self.depths)): + # Apply downsampling + for layer in self.downsample_layers[i]: + if isinstance(layer, SConv1d): + x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + else: + x = layer(x) + + # Apply stage (Block1D contains Convlayer which contains SConv1d) + for block in self.stages[i]: + if hasattr(block, 
'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d): + # Block1D forward with cache support + residual = x + x = block.norm(x) + x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + if block.gamma is not None: + x = x * block.gamma.unsqueeze(-1) + x = residual + x + + # FFN part + residual = x + x = block.ffn_norm(x) + x = x.permute(0, 2, 1) + x = block.ffn(x) + x = x.permute(0, 2, 1) + if block.ffn_gamma is not None: + x = x * block.ffn_gamma.unsqueeze(-1) + x = residual + x + else: + x = block(x) + + return self.norm(x) + + def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): + x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + return x + + +class TokenizerDecoder(nn.Module): + """ + Decoder component for the VibeVoice tokenizer that converts latent representations back to audio. + + Args: + config: Configuration object with model parameters + """ + def __init__(self, config): + super().__init__() + + # Extract parameters from config + self.dimension = config.dimension + self.channels = config.channels + self.n_filters = config.n_filters + self.ratios = config.ratios + + # IMPORTANT CHANGE: Don't reverse depths again since they're already reversed in VibeVoiceAcousticTokenizerModel + self.depths = config.depths # Changed from list(reversed(config.depths)) + + self.n_residual_layers = getattr(config, "n_residual_layers", 1) + self.hop_length = np.prod(self.ratios) + self.causal = config.causal + + # Additional config parameters with defaults + kernel_size = getattr(config, "kernel_size", 7) + last_kernel_size = getattr(config, "last_kernel_size", 7) + norm = getattr(config, "norm", "none") + norm_params = getattr(config, "norm_params", {}) + pad_mode = getattr(config, "pad_mode", "reflect") + bias = getattr(config, "bias", True) + layernorm = getattr(config, "layernorm", "LN") + layernorm_eps = getattr(config, "layernorm_eps", 1e-6) + trim_right_ratio = getattr(config, "trim_right_ratio", 1.0) + layernorm_elementwise_affine = getattr(config, "layernorm_elementwise_affine", True) + drop_path_rate = getattr(config, "drop_path_rate", 0.0) + mixer_layer = getattr(config, "mixer_layer", "conv") + layer_scale_init_value = getattr(config, "layer_scale_init_value", 0) + disable_last_norm = getattr(config, "disable_last_norm", False) + + # determine the norm type based on layernorm + if layernorm == 'LN': + norm_type = ConvLayerNorm + elif layernorm == 'RMSNorm': + norm_type = partial(ConvRMSNorm, elementwise_affine=layernorm_elementwise_affine) + else: + raise ValueError(f"Unsupported norm type: {layernorm}") + + # stem and upsampling layers + stem = nn.Sequential( + SConv1d(self.dimension, self.n_filters * 2 ** (len(self.depths) - 1), kernel_size, norm=norm, + norm_kwargs=norm_params, causal=self.causal, pad_mode=pad_mode, bias=bias), + ) + + self.upsample_layers = nn.ModuleList() + self.upsample_layers.append(stem) + for i in range(len(self.ratios)): + in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i)) + out_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i - 1)) + upsample_layer = nn.Sequential( + SConvTranspose1d(in_ch, out_ch, + kernel_size=self.ratios[i] * 2, stride=self.ratios[i], + norm=norm, norm_kwargs=norm_params, bias=bias, + causal=self.causal, trim_right_ratio=trim_right_ratio), + ) + 
self.upsample_layers.append(upsample_layer) + + # configure transformer blocks + layer_type = partial( + Block1D, + mixer_layer=mixer_layer, + layernorm=layernorm, + eps=layernorm_eps, + causal=self.causal, + pad_mode=pad_mode, + norm=norm, + bias=bias, + layer_scale_init_value=layer_scale_init_value, + ) + + self.stages = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))] + cur = 0 + + # Create stages in the same order as the original model + for i in range(len(self.depths)): + in_ch = self.n_filters * (2 ** (len(self.depths) - 1 - i)) + stage = nn.Sequential( + *[layer_type(dim=in_ch, drop_path=dp_rates[cur + j]) for j in range(self.depths[i])] + ) + self.stages.append(stage) + cur += self.depths[i] + + if not disable_last_norm: + self.norm = norm_type(in_ch, eps=layernorm_eps) + else: + self.norm = nn.Identity() + self.head = SConv1d(in_ch, self.channels, kernel_size=last_kernel_size, causal=self.causal, pad_mode=pad_mode, norm=norm, bias=bias) + + def forward_features(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): + for i in range(len(self.depths)): + # Apply upsampling + for layer in self.upsample_layers[i]: + if isinstance(layer, (SConv1d, SConvTranspose1d)): + x = layer(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + else: + x = layer(x) + + # Apply stage (Block1D contains Convlayer which contains SConv1d) + for block in self.stages[i]: + if hasattr(block, 'mixer') and hasattr(block.mixer, 'conv') and isinstance(block.mixer.conv, SConv1d): + # Block1D forward with cache support + residual = x + x = block.norm(x) + x = block.mixer.conv(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + if block.gamma is not None: + x = x * block.gamma.unsqueeze(-1) + x = residual + x + + # FFN part + residual = x + x = block.ffn_norm(x) + x = x.permute(0, 2, 1) + x = block.ffn(x) + x = x.permute(0, 2, 1) + if block.ffn_gamma is not None: + x = x * block.ffn_gamma.unsqueeze(-1) + x = residual + x + else: + x = block(x) + + return self.norm(x) + + def forward(self, x, cache=None, sample_indices=None, use_cache=False, debug=False): + x = self.forward_features(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + x = self.head(x, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + return x + + +@dataclass +class VibeVoiceTokenizerEncoderOutput: + """ + Output of VibeVoice tokenizer encoder, representing a Gaussian distribution with fixed variance. + + Args: + mean (`torch.FloatTensor`): The mean parameters of the distribution. + std (`float` or `torch.FloatTensor`): Fixed standard deviation value. + """ + mean: torch.Tensor + std: Optional[Union[float, torch.Tensor]] = None + + def sample(self, dist_type='fix'): + """ + Sample from the distribution. + + Args: + dist_type (`str`): Sampling method, either 'fix' or 'gaussian'. + + Returns: + `torch.FloatTensor`: Sampled values. + `torch.FloatTensor` (optional): Standard deviation used (only when dist_type='gaussian'). 
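+
+        Example (illustrative sketch, not taken from the reference pipeline; the
+        tensor shapes are arbitrary):
+
+            >>> import torch
+            >>> out = VibeVoiceTokenizerEncoderOutput(mean=torch.zeros(1, 10, 64), std=0.1)
+            >>> latents, std = out.sample(dist_type='fix')  # mean + fixed-std Gaussian noise
+            >>> latents.shape
+            torch.Size([1, 10, 64])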
+ """ + if dist_type == 'fix': + x = self.mean + self.std * torch.randn_like(self.mean) + return x, self.std + elif dist_type == 'gaussian': + batch_size = self.mean.size(0) + value = self.std / 0.8 + std = torch.randn(batch_size, device=self.mean.device, dtype=self.mean.dtype) * value + + while std.dim() < self.mean.dim(): + std = std.unsqueeze(-1) + + x = self.mean + std * torch.randn_like(self.mean) + return x, std + else: + return self.mean, self.std + + def kl(self): + """Compute KL divergence between this distribution and a standard normal.""" + target = torch.zeros_like(self.mean) + return F.mse_loss(self.mean, target, reduction='none') + + def mode(self): + """Return the distribution mode (which is the mean for Gaussian).""" + return self.mean + +class VibeVoiceAcousticTokenizerModel(PreTrainedModel): + """VibeVoice speech tokenizer model combining encoder and decoder for acoustic tokens""" + + config_class = VibeVoiceAcousticTokenizerConfig + base_model_prefix = "vibevoice_acoustic_tokenizer" + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["TokenizerEncoder", "TokenizerDecoder"] + + def __init__(self, config): + super().__init__(config) + + self.register_buffer('fix_std', torch.tensor(config.fix_std), persistent=False) + self.std_dist_type = getattr(config, "std_dist_type", "fix") + + # Parse encoder depths + if isinstance(config.encoder_depths, str): + encoder_depths = [int(d) for d in config.encoder_depths.split('-')] + else: + encoder_depths = config.encoder_depths + + # Parse decoder depths if provided + if config.decoder_depths is not None and isinstance(config.decoder_depths, str): + decoder_depths = [int(d) for d in config.decoder_depths.split('-')] + else: + # Default: use reversed encoder depths if decoder_depths is None + decoder_depths = list(reversed(encoder_depths)) + + # Create encoder config + encoder_config = copy.deepcopy(config) + encoder_config.dimension = config.vae_dim + encoder_config.n_filters = config.encoder_n_filters + encoder_config.ratios = config.encoder_ratios + encoder_config.depths = encoder_depths + encoder_config.norm = config.conv_norm + encoder_config.pad_mode = config.pad_mode + encoder_config.bias = config.conv_bias + encoder_config.layernorm_eps = config.layernorm_eps + encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine + encoder_config.mixer_layer = config.mixer_layer + encoder_config.layer_scale_init_value = config.layer_scale_init_value + encoder_config.disable_last_norm = config.disable_last_norm + + # Create decoder config + decoder_config = copy.deepcopy(config) + decoder_config.dimension = config.vae_dim + decoder_config.n_filters = config.decoder_n_filters + decoder_config.ratios = config.decoder_ratios + decoder_config.depths = decoder_depths + decoder_config.norm = config.conv_norm + decoder_config.pad_mode = config.pad_mode + decoder_config.bias = config.conv_bias + decoder_config.layernorm_eps = config.layernorm_eps + decoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine + decoder_config.mixer_layer = config.mixer_layer + decoder_config.layer_scale_init_value = config.layer_scale_init_value + decoder_config.disable_last_norm = config.disable_last_norm + + # Initialize encoder and decoder + self.encoder = TokenizerEncoder(encoder_config) + self.decoder = TokenizerDecoder(decoder_config) + + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, module): + """Initialize weights for the model""" + if isinstance(module, 
nn.Linear): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv1d): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + + @torch.no_grad() + def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False): + """Convert audio to latent representations""" + latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1), std=self.fix_std) + + @torch.no_grad() + def sampling(self, encoder_output, dist_type=None): + """Sample from the encoder output distribution""" + dist_type = dist_type or self.std_dist_type + + if dist_type == 'fix': + return encoder_output.sample(dist_type='fix') + elif dist_type == 'gaussian': + return encoder_output.sample(dist_type='gaussian') + else: + raise ValueError(f"Unsupported dist_type: {dist_type}, expected 'fix' or 'gaussian'") + + @torch.no_grad() + def decode(self, latents, cache=None, sample_indices=None, use_cache=False, debug=False): + """Convert latent representations back to audio""" + if latents.shape[1] == self.config.vae_dim: + pass + else: + latents = latents.permute(0, 2, 1) + + audio = self.decoder(latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + return audio + + def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False): + """Full forward pass: encode audio to latents, then decode back to audio""" + encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + sampled_latents, _ = self.sampling(encoder_output) + reconstructed = self.decode(sampled_latents, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + return reconstructed, sampled_latents + + +class VibeVoiceSemanticTokenizerModel(PreTrainedModel): + """VibeVoice speech tokenizer model with only encoder for semantic tokens""" + + config_class = VibeVoiceSemanticTokenizerConfig + base_model_prefix = "vibevoice_semantic_tokenizer" + _supports_flash_attn_2 = True + _supports_sdpa = True + _no_split_modules = ["TokenizerEncoder"] + + def __init__(self, config): + super().__init__(config) + + # Parse encoder depths + if isinstance(config.encoder_depths, str): + encoder_depths = [int(d) for d in config.encoder_depths.split('-')] + else: + encoder_depths = config.encoder_depths + + # Create encoder config + encoder_config = copy.deepcopy(config) + encoder_config.dimension = config.vae_dim + encoder_config.n_filters = config.encoder_n_filters + encoder_config.ratios = config.encoder_ratios + encoder_config.depths = encoder_depths + encoder_config.norm = config.conv_norm + encoder_config.pad_mode = config.pad_mode + encoder_config.bias = config.conv_bias + encoder_config.layernorm_eps = config.layernorm_eps + encoder_config.layernorm_elementwise_affine = config.layernorm_elementwise_affine + encoder_config.mixer_layer = config.mixer_layer + encoder_config.layer_scale_init_value = config.layer_scale_init_value + encoder_config.disable_last_norm = config.disable_last_norm + + # Initialize encoder and decoder + self.encoder = TokenizerEncoder(encoder_config) + + # Initialize weights + self.apply(self._init_weights) + + def 
_init_weights(self, module): + """Initialize weights for the model""" + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv1d): + nn.init.normal_(module.weight, std=self.config.weight_init_value) + if module.bias is not None: + nn.init.zeros_(module.bias) + + @torch.no_grad() + def encode(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False): + """Convert audio to latent representations""" + latents = self.encoder(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + return VibeVoiceTokenizerEncoderOutput(mean=latents.permute(0, 2, 1)) + + @torch.no_grad() + def sampling(self, encoder_output, dist_type=None): + """Sample from the encoder output distribution""" + return encoder_output.sample(dist_type='none') + + def forward(self, audio, cache=None, sample_indices=None, use_cache=False, debug=False): + """Full forward pass: encode audio to latents, then decode back to audio""" + encoder_output = self.encode(audio, cache=cache, sample_indices=sample_indices, use_cache=use_cache, debug=debug) + sampled_latents, _ = self.sampling(encoder_output, dist_type='none') + return None, sampled_latents + +AutoModel.register(VibeVoiceAcousticTokenizerConfig, VibeVoiceAcousticTokenizerModel) +AutoModel.register(VibeVoiceSemanticTokenizerConfig, VibeVoiceSemanticTokenizerModel) + +__all__ = [ + "VibeVoiceTokenizerStreamingCache", + "VibeVoiceAcousticTokenizerModel", + "VibeVoiceSemanticTokenizerModel", +] \ No newline at end of file diff --git a/vibevoice/modular/streamer.py b/vibevoice/modular/streamer.py new file mode 100644 index 0000000..7a76cb0 --- /dev/null +++ b/vibevoice/modular/streamer.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import torch + +import asyncio +from queue import Queue +from typing import TYPE_CHECKING, Optional + + +from transformers.generation import BaseStreamer + + +class AudioStreamer(BaseStreamer): + """ + Audio streamer that stores audio chunks in queues for each sample in the batch. + This allows streaming audio generation for multiple samples simultaneously. + + Parameters: + batch_size (`int`): + The batch size for generation + stop_signal (`any`, *optional*): + The signal to put in the queue when generation ends. Defaults to None. + timeout (`float`, *optional*): + The timeout for the audio queue. If `None`, the queue will block indefinitely. + """ + + def __init__( + self, + batch_size: int, + stop_signal: Optional[any] = None, + timeout: Optional[float] = None, + ): + self.batch_size = batch_size + self.stop_signal = stop_signal + self.timeout = timeout + + # Create a queue for each sample in the batch + self.audio_queues = [Queue() for _ in range(batch_size)] + self.finished_flags = [False for _ in range(batch_size)] + self.sample_indices_map = {} # Maps from sample index to queue index + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): + """ + Receives audio chunks and puts them in the appropriate queues. + + Args: + audio_chunks: Tensor of shape (num_samples, ...) 
containing audio chunks + sample_indices: Tensor indicating which samples these chunks belong to + """ + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and not self.finished_flags[idx]: + # Convert to numpy or keep as tensor based on preference + audio_chunk = audio_chunks[i].detach().cpu() + self.audio_queues[idx].put(audio_chunk, timeout=self.timeout) + + def end(self, sample_indices: Optional[torch.Tensor] = None): + """ + Signals the end of generation for specified samples or all samples. + + Args: + sample_indices: Optional tensor of sample indices to end. If None, ends all. + """ + if sample_indices is None: + # End all samples + for idx in range(self.batch_size): + if not self.finished_flags[idx]: + self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) + self.finished_flags[idx] = True + else: + # End specific samples + for sample_idx in sample_indices: + idx = sample_idx.item() if torch.is_tensor(sample_idx) else sample_idx + if idx < self.batch_size and not self.finished_flags[idx]: + self.audio_queues[idx].put(self.stop_signal, timeout=self.timeout) + self.finished_flags[idx] = True + + def __iter__(self): + """Returns an iterator over the batch of audio streams.""" + return AudioBatchIterator(self) + + def get_stream(self, sample_idx: int): + """Get the audio stream for a specific sample.""" + if sample_idx >= self.batch_size: + raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") + return AudioSampleIterator(self, sample_idx) + + +class AudioSampleIterator: + """Iterator for a single audio stream from the batch.""" + + def __init__(self, streamer: AudioStreamer, sample_idx: int): + self.streamer = streamer + self.sample_idx = sample_idx + + def __iter__(self): + return self + + def __next__(self): + value = self.streamer.audio_queues[self.sample_idx].get(timeout=self.streamer.timeout) + if value == self.streamer.stop_signal: + raise StopIteration() + return value + + +class AudioBatchIterator: + """Iterator that yields audio chunks for all samples in the batch.""" + + def __init__(self, streamer: AudioStreamer): + self.streamer = streamer + self.active_samples = set(range(streamer.batch_size)) + + def __iter__(self): + return self + + def __next__(self): + if not self.active_samples: + raise StopIteration() + + batch_chunks = {} + samples_to_remove = set() + + # Try to get chunks from all active samples + for idx in self.active_samples: + try: + value = self.streamer.audio_queues[idx].get(block=False) + if value == self.streamer.stop_signal: + samples_to_remove.add(idx) + else: + batch_chunks[idx] = value + except: + # Queue is empty for this sample, skip it this iteration + pass + + # Remove finished samples + self.active_samples -= samples_to_remove + + if batch_chunks: + return batch_chunks + elif self.active_samples: + # If no chunks were ready but we still have active samples, + # wait a bit and try again + import time + time.sleep(0.01) + return self.__next__() + else: + raise StopIteration() + + +class AsyncAudioStreamer(AudioStreamer): + """ + Async version of AudioStreamer for use in async contexts. 
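+
+    Usage sketch (illustrative, not part of the reference pipeline; assumes the
+    streamer is created inside a running event loop and that generation runs in a
+    separate thread which calls `put()` / `end()`):
+
+        streamer = AsyncAudioStreamer(batch_size=1)
+        # a generation thread feeds chunks via streamer.put(...) and finishes with streamer.end()
+        async for chunk in streamer.get_stream(0):
+            await handle_chunk(chunk)  # hypothetical consumer coroutine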
+ """ + + def __init__( + self, + batch_size: int, + stop_signal: Optional[any] = None, + timeout: Optional[float] = None, + ): + super().__init__(batch_size, stop_signal, timeout) + # Replace regular queues with async queues + self.audio_queues = [asyncio.Queue() for _ in range(batch_size)] + self.loop = asyncio.get_running_loop() + + def put(self, audio_chunks: torch.Tensor, sample_indices: torch.Tensor): + """Put audio chunks in the appropriate async queues.""" + for i, sample_idx in enumerate(sample_indices): + idx = sample_idx.item() + if idx < self.batch_size and not self.finished_flags[idx]: + audio_chunk = audio_chunks[i].detach().cpu() + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, audio_chunk + ) + + def end(self, sample_indices: Optional[torch.Tensor] = None): + """Signal the end of generation for specified samples.""" + if sample_indices is None: + indices_to_end = range(self.batch_size) + else: + indices_to_end = [s.item() if torch.is_tensor(s) else s for s in sample_indices] + + for idx in indices_to_end: + if idx < self.batch_size and not self.finished_flags[idx]: + self.loop.call_soon_threadsafe( + self.audio_queues[idx].put_nowait, self.stop_signal + ) + self.finished_flags[idx] = True + + async def get_stream(self, sample_idx: int): + """Get async iterator for a specific sample's audio stream.""" + if sample_idx >= self.batch_size: + raise ValueError(f"Sample index {sample_idx} exceeds batch size {self.batch_size}") + + while True: + value = await self.audio_queues[sample_idx].get() + if value == self.stop_signal: + break + yield value + + def __aiter__(self): + """Returns an async iterator over all audio streams.""" + return AsyncAudioBatchIterator(self) + + +class AsyncAudioBatchIterator: + """Async iterator for batch audio streaming.""" + + def __init__(self, streamer: AsyncAudioStreamer): + self.streamer = streamer + self.active_samples = set(range(streamer.batch_size)) + + def __aiter__(self): + return self + + async def __anext__(self): + if not self.active_samples: + raise StopAsyncIteration() + + batch_chunks = {} + samples_to_remove = set() + + # Create tasks for all active samples + tasks = { + idx: asyncio.create_task(self._get_chunk(idx)) + for idx in self.active_samples + } + + # Wait for at least one chunk to be ready + done, pending = await asyncio.wait( + tasks.values(), + return_when=asyncio.FIRST_COMPLETED, + timeout=self.streamer.timeout + ) + + # Cancel pending tasks + for task in pending: + task.cancel() + + # Process completed tasks + for idx, task in tasks.items(): + if task in done: + try: + value = await task + if value == self.streamer.stop_signal: + samples_to_remove.add(idx) + else: + batch_chunks[idx] = value + except asyncio.CancelledError: + pass + + self.active_samples -= samples_to_remove + + if batch_chunks: + return batch_chunks + elif self.active_samples: + # Try again if we still have active samples + return await self.__anext__() + else: + raise StopAsyncIteration() + + async def _get_chunk(self, idx): + """Helper to get a chunk from a specific queue.""" + return await self.streamer.audio_queues[idx].get() \ No newline at end of file diff --git a/vibevoice/processor/__init__.py b/vibevoice/processor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vibevoice/processor/vibevoice_processor.py b/vibevoice/processor/vibevoice_processor.py new file mode 100644 index 0000000..66d0a9d --- /dev/null +++ b/vibevoice/processor/vibevoice_processor.py @@ -0,0 +1,677 @@ +import math +import warnings +from 
typing import List, Optional, Union, Dict, Any, Tuple +import os +import re + +import numpy as np +import torch + +from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.utils import TensorType, logging +from .vibevoice_tokenizer_processor import AudioNormalizer + +logger = logging.get_logger(__name__) + + +class VibeVoiceProcessor: + r""" + Constructs a VibeVoice processor which wraps a VibeVoice tokenizer and audio processor into a single processor. + + [`VibeVoiceProcessor`] offers all the functionalities of [`VibeVoiceTokenizer`] and [`VibeVoiceTokenizerProcessor`]. + See the [`~VibeVoiceProcessor.__call__`] and [`~VibeVoiceProcessor.decode`] for more information. + + Args: + tokenizer (`VibeVoiceTextTokenizer` or `VibeVoiceTextTokenizerFast`): + The tokenizer for text processing. + audio_processor (`VibeVoiceTokenizerProcessor`): + The audio processor for speech processing. + speech_tok_compress_ratio (`int`, *optional*, defaults to 3200): + The compression ratio for speech tokenization. + db_normalize (`bool`, *optional*, defaults to True): + Whether to apply decibel normalization to audio inputs. + """ + + def __init__(self, tokenizer=None, audio_processor=None, speech_tok_compress_ratio=3200, db_normalize=True, **kwargs): + self.tokenizer = tokenizer + self.audio_processor = audio_processor + self.speech_tok_compress_ratio = speech_tok_compress_ratio + self.db_normalize = db_normalize + self.audio_normalizer = AudioNormalizer() if db_normalize else None + self.system_prompt = " Transform the text provided by various speakers into speech output, utilizing the distinct voice of each respective speaker.\n" + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + """ + Instantiate a VibeVoiceProcessor from a pretrained VibeVoice processor. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + - a string, the *model id* of a pretrained model + - a path to a *directory* containing processor config + + Returns: + [`VibeVoiceProcessor`]: The processor object instantiated from pretrained model. + """ + import os + import json + from .vibevoice_tokenizer_processor import VibeVoiceTokenizerProcessor + from vibevoice.modular.modular_vibevoice_text_tokenizer import ( + VibeVoiceTextTokenizer, + VibeVoiceTextTokenizerFast + ) + + # Load processor configuration + config_path = os.path.join(pretrained_model_name_or_path, "preprocessor_config.json") + if os.path.exists(config_path): + with open(config_path, 'r') as f: + config = json.load(f) + else: + logger.warning(f"No preprocessor_config.json found at {pretrained_model_name_or_path}, using defaults") + config = { + "speech_tok_compress_ratio": 3200, + "db_normalize": True, + } + + # Extract main processor parameters + speech_tok_compress_ratio = config.get("speech_tok_compress_ratio", 3200) + db_normalize = config.get("db_normalize", True) + + # Load tokenizer - try from model path first, then fallback to Qwen + language_model_pretrained_name = config.get("language_model_pretrained_name", None) or kwargs.pop("language_model_pretrained_name", "Qwen/Qwen2.5-1.5B") + logger.info(f"Loading tokenizer from {language_model_pretrained_name}") + if 'qwen' in language_model_pretrained_name.lower(): + tokenizer = VibeVoiceTextTokenizerFast.from_pretrained( + language_model_pretrained_name, + **kwargs + ) + else: + raise ValueError(f"Unsupported tokenizer type for {language_model_pretrained_name}. 
Supported types: Qwen, Llama, Gemma.") + + # Load audio processor + if "audio_processor" in config: + # Create audio processor from config + audio_config = config["audio_processor"] + audio_processor = VibeVoiceTokenizerProcessor( + sampling_rate=audio_config.get("sampling_rate", 24000), + normalize_audio=audio_config.get("normalize_audio", True), + target_dB_FS=audio_config.get("target_dB_FS", -25), + eps=audio_config.get("eps", 1e-6), + ) + else: + # Create default audio processor + audio_processor = VibeVoiceTokenizerProcessor() + + # Create and return the processor + return cls( + tokenizer=tokenizer, + audio_processor=audio_processor, + speech_tok_compress_ratio=speech_tok_compress_ratio, + db_normalize=db_normalize, + ) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + """ + Save a processor to a directory, so that it can be re-loaded using the + [`~VibeVoiceProcessor.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the processor will be saved. + """ + import os + import json + + os.makedirs(save_directory, exist_ok=True) + + # Save processor configuration + processor_config = { + "processor_class": "VibeVoiceProcessor", + "speech_tok_compress_ratio": self.speech_tok_compress_ratio, + "db_normalize": self.db_normalize, + "audio_processor": { + "feature_extractor_type": "VibeVoiceTokenizerProcessor", + "sampling_rate": getattr(self.audio_processor, 'sampling_rate', 24000), + "normalize_audio": getattr(self.audio_processor, 'normalize_audio', True), + "target_dB_FS": getattr(self.audio_processor, 'target_dB_FS', -25), + "eps": getattr(self.audio_processor, 'eps', 1e-6), + } + } + + config_path = os.path.join(save_directory, "preprocessor_config.json") + with open(config_path, 'w') as f: + json.dump(processor_config, f, indent=2) + + logger.info(f"Processor configuration saved in {config_path}") + + def __call__( + self, + text: Optional[Union[str, List[str], TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None, + voice_samples: Optional[Union[List[Union[str, np.ndarray]], List[List[Union[str, np.ndarray]]]]] = None, + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to process one or more podcast scripts with optional voice samples. + + Args: + text (`str`, `List[str]`): + The input text(s) to process. Can be: + - A single script string + - A list of script strings for batch processing + - A path to a .json or .txt file + - A list of paths + voice_samples (`List[Union[str, np.ndarray]]`, `List[List[Union[str, np.ndarray]]]`, *optional*): + Voice samples for each script. 
Can be: + - A list of samples for a single script + - A list of lists for batch processing + padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`): + Whether to pad sequences to the same length + truncation (`bool`, `str` or `TruncationStrategy`, defaults to `False`): + Whether to truncate sequences + max_length (`int`, *optional*): + Maximum length of the returned sequences + return_tensors (`str` or `TensorType`, *optional*): + If set, will return tensors of a particular framework + return_attention_mask (`bool`, defaults to `True`): + Whether to return the attention mask + + Returns: + `BatchEncoding`: A BatchEncoding with the following fields: + - **input_ids** -- List of token id sequences or tensor + - **attention_mask** -- List of attention masks or tensor + - **speech_tensors** -- Padded speech inputs (if voice_samples provided) + - **speech_masks** -- Speech masks (if voice_samples provided) + - **speech_input_mask** -- Boolean masks indicating speech token positions + """ + # Handle single vs batch input + if isinstance(text, str) or (isinstance(text, list) and len(text) > 0 and not isinstance(text[0], str)): + # Single input + texts = [text] + is_batched = False + else: + # Batch input + texts = text + is_batched = True + + # Handle voice samples + if voice_samples is not None: + if not is_batched or (isinstance(voice_samples[0], (str, np.ndarray))): + # Single set of voice samples + voice_samples_list = [voice_samples] + else: + # Batch of voice samples + voice_samples_list = voice_samples + else: + voice_samples_list = [None] * len(texts) + + # Process each input + all_encodings = [] + for text_input, voice_input in zip(texts, voice_samples_list): + encoding = self._process_single(text_input, voice_input) + all_encodings.append(encoding) + + # Combine batch + batch_encoding = self._batch_encode( + all_encodings, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + return_attention_mask=return_attention_mask, + ) + + return batch_encoding + + def _process_single( + self, + text: Union[str, TextInput], + voice_samples: Optional[List[Union[str, np.ndarray]]] = None, + ) -> Dict[str, Any]: + """Process a single podcast script.""" + # Determine if text is a file path or direct script + script = None + if isinstance(text, str): + # Check if it's a file path + if text.endswith('.json') and os.path.exists(text): + script = self._convert_json_to_script(text) + elif text.endswith('.txt') and os.path.exists(text): + script = self._convert_text_to_script(text) + else: + # Assume it's the script content directly + script = text + + if script is None: + raise ValueError(f"Could not process input text: {text}") + + # Parse the script + parsed_lines = self._parse_script(script) + all_speakers = list(set(speaker_id for speaker_id, _ in parsed_lines)) + + # Create system prompt + # system_tokens = self.tokenizer.encode(self.system_prompt, add_special_tokens=False) + system_tokens = self.tokenizer.encode(self.system_prompt) + + # Process voice samples if provided + if voice_samples: + voice_tokens, voice_speech_inputs, voice_speech_masks = self._create_voice_prompt(voice_samples[:len(all_speakers)]) + else: + voice_tokens, voice_speech_inputs, voice_speech_masks = [], [], [] + + # Build full token sequence + full_tokens = system_tokens + voice_tokens + speech_input_mask = [False] * len(system_tokens) + voice_speech_masks + + # Add text input section + full_tokens += self.tokenizer.encode(' Text input:\n', add_special_tokens=False) + 
speech_input_mask += [False] * len(self.tokenizer.encode(' Text input:\n', add_special_tokens=False)) + + for speaker_id, speaker_text in parsed_lines: + speaker_text_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:{speaker_text}\n", add_special_tokens=False) + full_tokens += speaker_text_tokens + speech_input_mask += [False] * len(speaker_text_tokens) + + # Add speech output section + full_tokens += self.tokenizer.encode(' Speech output:\n', add_special_tokens=False) + [self.tokenizer.speech_start_id] + speech_input_mask += [False] * (len(self.tokenizer.encode(' Speech output:\n', add_special_tokens=False)) + 1) + + return { + "input_ids": full_tokens, + "speech_inputs": voice_speech_inputs if voice_speech_inputs else None, + "speech_input_mask": speech_input_mask, + "parsed_script": parsed_lines, + "all_speakers": all_speakers, + } + + def _batch_encode( + self, + encodings: List[Dict[str, Any]], + padding: Union[bool, str, PaddingStrategy] = True, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_attention_mask: bool = True, + ) -> BatchEncoding: + """Combine multiple encodings into a batch with padding.""" + # Extract input_ids and create attention_mask + input_ids_list = [enc["input_ids"] for enc in encodings] + speech_input_masks_list = [enc["speech_input_mask"] for enc in encodings] + + # Determine padding strategy + if isinstance(padding, bool): + padding_strategy = PaddingStrategy.LONGEST if padding else PaddingStrategy.DO_NOT_PAD + elif isinstance(padding, str): + padding_strategy = PaddingStrategy(padding) + else: + padding_strategy = padding + + # Apply padding to input_ids + if padding_strategy != PaddingStrategy.DO_NOT_PAD: + if padding_strategy == PaddingStrategy.LONGEST: + max_len = max(len(ids) for ids in input_ids_list) + elif padding_strategy == PaddingStrategy.MAX_LENGTH and max_length is not None: + max_len = max_length + else: + max_len = max(len(ids) for ids in input_ids_list) + + # Pad sequences + padded_input_ids = [] + attention_masks = [] + padded_speech_input_masks = [] + + for input_ids, speech_mask in zip(input_ids_list, speech_input_masks_list): + # Truncate if needed + if truncation and len(input_ids) > max_len: + input_ids = input_ids[:max_len] + speech_mask = speech_mask[:max_len] + + # Pad + padding_length = max_len - len(input_ids) + # padded_ids = [self.tokenizer.pad_token_id] * padding_length + input_ids + padded_ids = [self.tokenizer.pad_id] * padding_length + input_ids + attention_mask = [0] * padding_length + [1] * len(input_ids) + padded_speech_mask = [False] * padding_length + speech_mask + + padded_input_ids.append(padded_ids) + attention_masks.append(attention_mask) + padded_speech_input_masks.append(padded_speech_mask) + + input_ids_list = padded_input_ids + speech_input_masks_list = padded_speech_input_masks + else: + # No padding, just create attention masks + attention_masks = [[1] * len(ids) for ids in input_ids_list] if return_attention_mask else None + + # Process speech inputs + all_speech_inputs = [] + has_speech = False + for enc in encodings: + if enc["speech_inputs"] is not None: + all_speech_inputs.extend(enc["speech_inputs"]) + has_speech = True + + # Prepare batch encoding + batch_encoding = BatchEncoding() + + # Handle tensor conversion + if return_tensors is not None: + batch_encoding["input_ids"] = torch.tensor(input_ids_list, dtype=torch.long) + if return_attention_mask and attention_masks is not None: + 
batch_encoding["attention_mask"] = torch.tensor(attention_masks, dtype=torch.long) + batch_encoding["speech_input_mask"] = torch.tensor(speech_input_masks_list, dtype=torch.bool) + else: + batch_encoding["input_ids"] = input_ids_list + if return_attention_mask and attention_masks is not None: + batch_encoding["attention_mask"] = attention_masks + batch_encoding["speech_input_mask"] = speech_input_masks_list + + # Process speech tensors if present + if has_speech: + speech_dict = self.prepare_speech_inputs( + all_speech_inputs, + return_tensors=return_tensors, + ) + batch_encoding["speech_tensors"] = speech_dict["padded_speeches"] + batch_encoding["speech_masks"] = speech_dict["speech_masks"] + else: + batch_encoding["speech_tensors"] = None + batch_encoding["speech_masks"] = None + + # Add metadata + batch_encoding["parsed_scripts"] = [enc["parsed_script"] for enc in encodings] + batch_encoding["all_speakers_list"] = [enc["all_speakers"] for enc in encodings] + + return batch_encoding + + def _create_voice_prompt( + self, + speaker_samples: List[Union[str, np.ndarray]] + ) -> Tuple[List[int], List[np.ndarray], List[bool]]: + """ + Create voice prompt tokens and process audio samples. + + Returns: + tuple: (voice_tokens, voice_speech_inputs, voice_speech_masks) + """ + vae_token_id = self.tokenizer.speech_diffusion_id + + voice_full_tokens = self.tokenizer.encode(' Voice input:\n', add_special_tokens=False) + voice_speech_inputs = [] + voice_speech_masks = [False] * len(voice_full_tokens) + + for speaker_id, speaker_audio in enumerate(speaker_samples): + prefix_tokens = self.tokenizer.encode(f" Speaker {speaker_id}:", add_special_tokens=False) + + # Process audio + if isinstance(speaker_audio, str): + # Load audio from file + wav = self.audio_processor._load_audio_from_path(speaker_audio) + else: + wav = np.array(speaker_audio, dtype=np.float32) + + # Apply normalization if needed + if self.db_normalize and self.audio_normalizer: + wav = self.audio_normalizer(wav) + + # Calculate token length based on compression ratio + # if speaker_audio.endswith('.pt') or speaker_audio.endswith('.npy'): + # vae_tok_len = wav.shape[0] + # else: + vae_tok_len = math.ceil(wav.shape[0] / self.speech_tok_compress_ratio) + + # Build tokens and masks + speaker_tokens = (prefix_tokens + + [self.tokenizer.speech_start_id] + + [vae_token_id] * vae_tok_len + + [self.tokenizer.speech_end_id] + + self.tokenizer.encode('\n', add_special_tokens=False)) + + vae_input_mask = ([False] * len(prefix_tokens) + + [False] + + [True] * vae_tok_len + + [False] + + [False]) + + voice_full_tokens.extend(speaker_tokens) + voice_speech_masks.extend(vae_input_mask) + voice_speech_inputs.append(wav) + + return voice_full_tokens, voice_speech_inputs, voice_speech_masks + + def prepare_speech_inputs( + self, + speech_inputs: List[np.ndarray], + return_tensors: Optional[Union[str, TensorType]] = None, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[torch.dtype] = None, + ) -> Dict[str, Any]: + """ + Prepare speech inputs for model consumption. 
+ + Args: + speech_inputs: List of speech arrays + return_tensors: Output tensor type + device: Device to place tensors on + dtype: Data type for tensors + + Returns: + Dictionary with padded_speeches and speech_masks + """ + if not speech_inputs: + return {"padded_speeches": None, "speech_masks": None} + + # Calculate sequence lengths + vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) for s in speech_inputs] + # vae_tok_seqlens = [math.ceil(s.shape[0] / self.speech_tok_compress_ratio) if s.ndim == 1 else s.shape[0] for s in speech_inputs] + max_speech_length = max(s.shape[0] for s in speech_inputs) + + # Pad speeches + if speech_inputs[0].ndim == 1: + padded_speeches = np.full((len(speech_inputs), max_speech_length), fill_value=0, dtype=np.float32) + else: + padded_speeches = np.full((len(speech_inputs), max_speech_length, speech_inputs[0].shape[-1]), fill_value=0, dtype=np.float32) + speech_masks = np.zeros((len(speech_inputs), max(vae_tok_seqlens)), dtype=np.bool_) + + for i, (speech, vae_tok_length) in enumerate(zip(speech_inputs, vae_tok_seqlens)): + padded_speeches[i, :len(speech)] = speech + speech_masks[i, :vae_tok_length] = True + + result = { + "padded_speeches": padded_speeches, + "speech_masks": speech_masks, + } + + # Convert to tensors if requested + if return_tensors == "pt": + result["padded_speeches"] = torch.tensor(padded_speeches, device=device, dtype=dtype or torch.float32) + result["speech_masks"] = torch.tensor(speech_masks, device=device, dtype=torch.bool) + + return result + + def _convert_json_to_script(self, json_file: str) -> str: + """ + Convert JSON format to script format. + Expected JSON format: + [ + {"speaker": "1", "text": "Hello everyone..."}, + {"speaker": "2", "text": "Great to be here..."} + ] + """ + import json + + with open(json_file, 'r', encoding='utf-8') as f: + data = json.load(f) + + if not isinstance(data, list): + raise ValueError("JSON file must contain a list of speaker entries") + + script_lines = [] + for item in data: + if not isinstance(item, dict): + logger.warning(f"Skipping non-dict entry: {item}") + continue + + speaker = item.get('speaker') + text = item.get('text') + + if speaker is None or text is None: + logger.warning(f"Skipping entry missing speaker or text: {item}") + continue + + # Ensure speaker ID is valid + try: + speaker_id = int(speaker) + except (ValueError, TypeError): + logger.warning(f"Invalid speaker ID: {speaker}, skipping entry") + continue + + # Clean up text + text = text.strip() + if text: + script_lines.append(f"Speaker {speaker_id}: {text}") + + if not script_lines: + raise ValueError("No valid entries found in JSON file") + + return "\n".join(script_lines) + + def _convert_text_to_script(self, text_file: str) -> str: + """ + Convert text file to script format. + Handles multiple formats: + 1. Already formatted as "Speaker X: text" + 2. Plain text (assigns to Speaker 1) + + Handles edge cases like multiple colons in a line. 
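+
+        Example (illustrative; shows the intended line mapping, not actual file contents):
+
+            "Speaker 1: Hello there."     -> "Speaker 1: Hello there."
+            "Just a plain line of text."  -> "Speaker 1: Just a plain line of text."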
+ """ + with open(text_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + script_lines = [] + current_speaker = 1 + + for line in lines: + line = line.strip() + if not line: + continue + + # Try to parse as "Speaker X: text" format + # Use regex to be more robust + speaker_match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line, re.IGNORECASE) + + if speaker_match: + speaker_id = int(speaker_match.group(1)) + text = speaker_match.group(2).strip() + if text: + script_lines.append(f"Speaker {speaker_id}: {text}") + else: + # Treat as plain text - assign to current speaker + script_lines.append(f"Speaker {current_speaker}: {line}") + + if not script_lines: + raise ValueError("No valid content found in text file") + + return "\n".join(script_lines) + + def _parse_script(self, script: str) -> List[Tuple[int, str]]: + """Parse script into list of (speaker_id, text) tuples.""" + lines = script.strip().split("\n") + parsed_lines = [] + speaker_ids = [] + + # First pass: parse all lines and collect speaker IDs + for line in lines: + if not line.strip(): + continue + + # Use regex to handle edge cases like multiple colons + match = re.match(r'^Speaker\s+(\d+)\s*:\s*(.*)$', line.strip(), re.IGNORECASE) + + if match: + speaker_id = int(match.group(1)) + text = ' ' + match.group(2).strip() + parsed_lines.append((speaker_id, text)) + speaker_ids.append(speaker_id) + else: + logger.warning(f"Could not parse line: '{line}'") + + if not parsed_lines: + raise ValueError("No valid speaker lines found in script") + + # Check if we need to normalize speaker IDs (only if all are > 0) + min_speaker_id = min(speaker_ids) + if min_speaker_id > 0: + # Normalize to start from 0 + normalized_lines = [] + for speaker_id, text in parsed_lines: + normalized_lines.append((speaker_id - 1, text)) + return normalized_lines + else: + # Keep original IDs + return parsed_lines + + def _merge_inputs(self, text_inputs: BatchEncoding, audio_inputs: Dict) -> BatchEncoding: + """Merge text and audio inputs into a single BatchEncoding.""" + # Start with text inputs + merged = BatchEncoding(text_inputs) + + # Add audio-specific fields + if "audio" in audio_inputs: + merged["speech_inputs"] = audio_inputs["audio"] + if "streaming" in audio_inputs: + merged["streaming"] = audio_inputs["streaming"] + + return merged + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.batch_decode`]. + Please refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to VibeVoiceTextTokenizer's [`~PreTrainedTokenizer.decode`]. + Please refer to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + """ + Return the list of inputs accepted by the model. + """ + tokenizer_input_names = self.tokenizer.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + audio_processor_input_names + ["speech_inputs", "speech_input_mask"])) + + def save_audio(self, + audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], + output_path: str = "output.wav", + sampling_rate: Optional[int] = None, + normalize: bool = False, + batch_prefix: str = "audio_", + ) -> str: + """ + Save audio data to a file. 
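+
+        This is a thin wrapper that forwards to the audio processor's `save_audio`.
+
+        Example (illustrative sketch; `processor` and `generated_audio` are
+        hypothetical names for an instantiated VibeVoiceProcessor and a generated
+        waveform tensor):
+
+            >>> processor.save_audio(generated_audio, output_path="podcast.wav")
+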
+ Args: + audio (Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]): + The audio data to save. Can be a single tensor/array or a list of them. + output_path (str, optional): Path to save the audio file. Defaults to "output.wav". + sampling_rate (int, optional): Sampling rate for the audio. If None, uses the processor's default. + normalize (bool, optional): Whether to normalize the audio before saving. Defaults to False. + batch_prefix (str, optional): Prefix for batch audio files. Defaults to "audio_". + Returns: + str: The path to the saved audio file. + """ + return self.audio_processor.save_audio(audio, output_path=output_path, sampling_rate=sampling_rate, normalize=normalize, batch_prefix=batch_prefix) + +__all__ = [ + "VibeVoiceProcessor", +] \ No newline at end of file diff --git a/vibevoice/processor/vibevoice_tokenizer_processor.py b/vibevoice/processor/vibevoice_tokenizer_processor.py new file mode 100644 index 0000000..0d854b7 --- /dev/null +++ b/vibevoice/processor/vibevoice_tokenizer_processor.py @@ -0,0 +1,483 @@ +""" +Processor class for VibeVoice models. +""" + +import os +import json +import warnings +from typing import List, Optional, Union, Dict, Any + +import numpy as np +import torch + +from transformers.feature_extraction_utils import FeatureExtractionMixin +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class AudioNormalizer: + """ + Audio normalization class for VibeVoice tokenizer. + + This class provides audio normalization to ensure consistent input levels + for the VibeVoice tokenizer while maintaining audio quality. + """ + + def __init__(self, target_dB_FS: float = -25, eps: float = 1e-6): + """ + Initialize the audio normalizer. + + Args: + target_dB_FS (float): Target dB FS level for the audio. Default: -25 + eps (float): Small value to avoid division by zero. Default: 1e-6 + """ + self.target_dB_FS = target_dB_FS + self.eps = eps + + def tailor_dB_FS(self, audio: np.ndarray) -> tuple: + """ + Adjust the audio to the target dB FS level. + + Args: + audio (np.ndarray): Input audio signal + + Returns: + tuple: (normalized_audio, rms, scalar) + """ + rms = np.sqrt(np.mean(audio**2)) + scalar = 10 ** (self.target_dB_FS / 20) / (rms + self.eps) + normalized_audio = audio * scalar + return normalized_audio, rms, scalar + + def avoid_clipping(self, audio: np.ndarray, scalar: Optional[float] = None) -> tuple: + """ + Avoid clipping by scaling down if necessary. + + Args: + audio (np.ndarray): Input audio signal + scalar (float, optional): Explicit scaling factor + + Returns: + tuple: (normalized_audio, scalar) + """ + if scalar is None: + max_val = np.max(np.abs(audio)) + if max_val > 1.0: + scalar = max_val + self.eps + else: + scalar = 1.0 + + return audio / scalar, scalar + + def __call__(self, audio: np.ndarray) -> np.ndarray: + """ + Normalize the audio by adjusting to target dB FS and avoiding clipping. + + Args: + audio (np.ndarray): Input audio signal + + Returns: + np.ndarray: Normalized audio signal + """ + # First adjust to target dB FS + audio, _, _ = self.tailor_dB_FS(audio) + # Then avoid clipping + audio, _ = self.avoid_clipping(audio) + return audio + + +# Change from ProcessorMixin to FeatureExtractionMixin which is designed for single components +class VibeVoiceTokenizerProcessor(FeatureExtractionMixin): + """ + Processor for VibeVoice acoustic tokenizer models. 
+ + This processor handles audio preprocessing for VibeVoice models, including: + - Audio format conversion (stereo to mono) + - Optional audio normalization + - Streaming support for infinite-length audio + + Args: + sampling_rate (int, optional): Expected sampling rate. Defaults to 24000. + normalize_audio (bool, optional): Whether to normalize audio. Defaults to True. + target_dB_FS (float, optional): Target dB FS for normalization. Defaults to -25. + eps (float, optional): Small value for numerical stability. Defaults to 1e-6. + """ + model_input_names = ["input_features"] + + def __init__( + self, + sampling_rate: int = 24000, + normalize_audio: bool = True, + target_dB_FS: float = -25, + eps: float = 1e-6, + **kwargs, + ): + super().__init__(**kwargs) + + self.sampling_rate = sampling_rate + self.normalize_audio = normalize_audio + + # Initialize audio normalizer if needed + if self.normalize_audio: + self.normalizer = AudioNormalizer(target_dB_FS=target_dB_FS, eps=eps) + else: + self.normalizer = None + + # Save config + self.feature_extractor_dict = { + "sampling_rate": sampling_rate, + "normalize_audio": normalize_audio, + "target_dB_FS": target_dB_FS, + "eps": eps, + } + + def _ensure_mono(self, audio: np.ndarray) -> np.ndarray: + """ + Convert stereo audio to mono if needed. + + Args: + audio (np.ndarray): Input audio array + + Returns: + np.ndarray: Mono audio array + """ + if len(audio.shape) == 1: + return audio + elif len(audio.shape) == 2: + if audio.shape[0] == 2: # (2, time) + return np.mean(audio, axis=0) + elif audio.shape[1] == 2: # (time, 2) + return np.mean(audio, axis=1) + else: + # If one dimension is 1, squeeze it + if audio.shape[0] == 1: + return audio.squeeze(0) + elif audio.shape[1] == 1: + return audio.squeeze(1) + else: + raise ValueError(f"Unexpected audio shape: {audio.shape}") + else: + raise ValueError(f"Audio should be 1D or 2D, got shape: {audio.shape}") + + def _process_single_audio(self, audio: Union[np.ndarray, List[float]]) -> np.ndarray: + """ + Process a single audio array. + + Args: + audio: Single audio input + + Returns: + np.ndarray: Processed audio + """ + # Convert to numpy array + if not isinstance(audio, np.ndarray): + audio = np.array(audio, dtype=np.float32) + else: + audio = audio.astype(np.float32) + + # Ensure mono + audio = self._ensure_mono(audio) + + # Normalize if requested + if self.normalize_audio and self.normalizer is not None: + audio = self.normalizer(audio) + + return audio + + def __call__( + self, + audio: Union[str, np.ndarray, List[float], List[np.ndarray], List[List[float]], List[str]] = None, + sampling_rate: Optional[int] = None, + return_tensors: Optional[str] = None, + **kwargs, + ): + """ + Process audio for VibeVoice models. + + Args: + audio: Audio input(s) to process. Can be: + - str: Path to audio file + - np.ndarray: Audio array + - List[float]: Audio as list of floats + - List[np.ndarray]: Batch of audio arrays + - List[str]: Batch of audio file paths + sampling_rate (int, optional): Sampling rate of the input audio + return_tensors (str, optional): Return format ('pt' for PyTorch, 'np' for NumPy) + + Returns: + dict: Processed audio inputs with keys: + - input_features: Audio tensor(s) ready for the model + """ + if audio is None: + raise ValueError("Audio input is required") + + # Validate sampling rate + if sampling_rate is not None and sampling_rate != self.sampling_rate: + logger.warning( + f"Input sampling rate ({sampling_rate}) differs from expected " + f"sampling rate ({self.sampling_rate}). 
Please resample your audio." + ) + + # Handle different input types + if isinstance(audio, str): + # Single audio file path + audio = self._load_audio_from_path(audio) + is_batched = False + elif isinstance(audio, list): + if len(audio) == 0: + raise ValueError("Empty audio list provided") + + # Check if it's a list of file paths + if all(isinstance(item, str) for item in audio): + # Batch of audio file paths + audio = [self._load_audio_from_path(path) for path in audio] + is_batched = True + else: + # Check if it's batched audio arrays + is_batched = isinstance(audio[0], (np.ndarray, list)) + else: + # Single audio array or list + is_batched = False + + # Process audio + if is_batched: + processed_audio = [self._process_single_audio(a) for a in audio] + else: + processed_audio = [self._process_single_audio(audio)] + + # Convert to tensors if requested + if return_tensors == "pt": + if len(processed_audio) == 1: + # Create a proper batch dimension (B, T) + input_features = torch.from_numpy(processed_audio[0]).unsqueeze(0).unsqueeze(1) + else: + # For batched input with different lengths, create a batch properly + input_features = torch.stack([torch.from_numpy(a) for a in processed_audio]).unsqueeze(1) + elif return_tensors == "np": + if len(processed_audio) == 1: + input_features = processed_audio[0][np.newaxis, np.newaxis, :] + else: + input_features = np.stack(processed_audio)[:, np.newaxis, :] + else: + input_features = processed_audio[0] if len(processed_audio) == 1 else processed_audio + + outputs = { + "audio": input_features, # Use "audio" instead of "input_features" + } + + return outputs + + def _load_audio_from_path(self, audio_path: str) -> np.ndarray: + """ + Load audio from file path. + + Args: + audio_path (str): Path to audio file + + Returns: + np.ndarray: Loaded audio array + """ + # Get file extension to determine loading method + file_ext = os.path.splitext(audio_path)[1].lower() + + if file_ext in ['.wav', '.mp3', '.flac', '.m4a', '.ogg']: + # Audio file - use librosa + import librosa + audio_array, sr = librosa.load( + audio_path, + sr=self.sampling_rate, + mono=True + ) + return audio_array + elif file_ext == '.pt': + # PyTorch tensor file + audio_tensor = torch.load(audio_path, map_location='cpu').squeeze() + if isinstance(audio_tensor, torch.Tensor): + audio_array = audio_tensor.numpy() + else: + audio_array = np.array(audio_tensor) + return audio_array.astype(np.float32) + elif file_ext == '.npy': + # NumPy file + audio_array = np.load(audio_path) + return audio_array.astype(np.float32) + else: + raise ValueError( + f"Unsupported file format: {file_ext}. " + f"Supported formats: .wav, .mp3, .flac, .m4a, .ogg, .pt, .npy, .npz" + ) + + def preprocess_audio( + self, + audio_path_or_array: Union[str, np.ndarray], + normalize: Optional[bool] = None, + ) -> np.ndarray: + """ + Convenience method to preprocess audio from file path or array. + This method is kept for backward compatibility but __call__ is recommended. 
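+
+        Example (illustrative sketch; assumes `tokenizer_processor` is an instantiated
+        VibeVoiceTokenizerProcessor and the file path is hypothetical):
+
+            >>> audio = tokenizer_processor.preprocess_audio("speaker0.wav")
+            >>> audio.dtype, audio.ndim
+            (dtype('float32'), 1)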
+ + Args: + audio_path_or_array: Path to audio file or numpy array + normalize: Whether to normalize (overrides default setting) + + Returns: + np.ndarray: Preprocessed audio array + """ + if isinstance(audio_path_or_array, str): + audio_array = self._load_audio_from_path(audio_path_or_array) + else: + audio_array = np.array(audio_path_or_array, dtype=np.float32) + + # Override normalization setting if specified + original_normalize = self.normalize_audio + if normalize is not None: + self.normalize_audio = normalize + + try: + processed = self._process_single_audio(audio_array) + finally: + # Restore original setting + self.normalize_audio = original_normalize + + return processed + + # Override to_dict method for configuration saving + def to_dict(self) -> Dict[str, Any]: + """ + Convert the object to a dict containing all attributes needed for serialization. + """ + return self.feature_extractor_dict + + def save_audio( + self, + audio: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]], + output_path: str = "output.wav", + sampling_rate: Optional[int] = None, + normalize: bool = False, + batch_prefix: str = "audio_", + ): + """ + Save audio data to WAV file(s). + + Args: + audio: Audio data to save. Can be: + - torch.Tensor: PyTorch tensor with shape (B, C, T) or (B, T) or (T) + - np.ndarray: NumPy array with shape (B, C, T) or (B, T) or (T) + - List of tensors or arrays + output_path: Path where to save the audio. If saving multiple files, + this is treated as a directory and individual files will be saved inside. + sampling_rate: Sampling rate for the saved audio. Defaults to the processor's rate. + normalize: Whether to normalize audio before saving. + batch_prefix: Prefix for batch files when saving multiple audios. + + Returns: + List[str]: Paths to the saved audio files. + """ + if sampling_rate is None: + sampling_rate = self.sampling_rate + + try: + import soundfile as sf + except ImportError: + raise ImportError( + "soundfile is required to save audio files. 
" + "Install it with: pip install soundfile" + ) + + # Ensure audio is in the right format + if isinstance(audio, torch.Tensor): + # Convert PyTorch tensor to numpy + audio_np = audio.float().detach().cpu().numpy() + elif isinstance(audio, np.ndarray): + audio_np = audio + elif isinstance(audio, list): + # Handle list of tensors or arrays + if all(isinstance(a, torch.Tensor) for a in audio): + audio_np = [a.float().detach().cpu().numpy() for a in audio] + else: + audio_np = audio + else: + raise ValueError(f"Unsupported audio type: {type(audio)}") + + saved_paths = [] + + # Handle based on shape or type + if isinstance(audio_np, list): + # Multiple separate audios to save + output_dir = output_path + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Save each audio + for i, audio_item in enumerate(audio_np): + audio_item = self._prepare_audio_for_save(audio_item, normalize) + file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav") + sf.write(file_path, audio_item, sampling_rate) + saved_paths.append(file_path) + + else: + # Handle different dimensions + if len(audio_np.shape) >= 3: # (B, C, T) or similar + # Get batch size + batch_size = audio_np.shape[0] + + if batch_size > 1: + # Multiple audios in a batch + output_dir = output_path + + # Ensure output directory exists + os.makedirs(output_dir, exist_ok=True) + + # Save each audio in the batch + for i in range(batch_size): + # Extract single audio and remove channel dim if present + single_audio = audio_np[i] + if len(single_audio.shape) > 1: + if single_audio.shape[0] == 1: # (1, T) + single_audio = single_audio.squeeze(0) + + single_audio = self._prepare_audio_for_save(single_audio, normalize) + file_path = os.path.join(output_dir, f"{batch_prefix}{i}.wav") + sf.write(file_path, single_audio, sampling_rate) + saved_paths.append(file_path) + else: + # Single audio with batch and channel dims + audio_item = audio_np.squeeze() # Remove batch and channel dimensions + audio_item = self._prepare_audio_for_save(audio_item, normalize) + sf.write(output_path, audio_item, sampling_rate) + saved_paths.append(output_path) + else: + # Single audio without batch dimension + audio_item = self._prepare_audio_for_save(audio_np, normalize) + sf.write(output_path, audio_item, sampling_rate) + saved_paths.append(output_path) + + return saved_paths + + def _prepare_audio_for_save(self, audio: np.ndarray, normalize: bool) -> np.ndarray: + """ + Prepare audio for saving by ensuring it's the right shape and optionally normalizing. + + Args: + audio: Audio data as numpy array + normalize: Whether to normalize audio + + Returns: + np.ndarray: Processed audio ready for saving + """ + # Ensure right dimensionality + if len(audio.shape) > 1 and audio.shape[0] == 1: # (1, T) + audio = audio.squeeze(0) + + # Normalize if requested + if normalize: + max_val = np.abs(audio).max() + if max_val > 0: + audio = audio / max_val + + return audio + + +__all__ = ["VibeVoiceTokenizerProcessor", "AudioNormalizer"] \ No newline at end of file diff --git a/vibevoice/schedule/__init__.py b/vibevoice/schedule/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vibevoice/schedule/dpm_solver.py b/vibevoice/schedule/dpm_solver.py new file mode 100644 index 0000000..806241f --- /dev/null +++ b/vibevoice/schedule/dpm_solver.py @@ -0,0 +1,1065 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import deprecate +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + # return math.cos(t * math.pi / 2 * 0.95) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + elif alpha_transform_type == "cauchy": + # µ + γ tan (π (0.5 - x)) γ = 1, µ = 3 + # alpha^2 = 1-1/(exp(λ)+1) + def alpha_bar_fn(t, gamma=1, mu=3): + snr = mu + gamma * math.tan(math.pi * (0.5 - t) * 0.9) + return 1 - 1 / (math.exp(snr) + 1.1) + + elif alpha_transform_type == "laplace": + # µ − bsgn(0.5 − t) log(1 − 2|t − 0.5|) µ = 0, b = 1 + def alpha_bar_fn(t, mu=0, b=1): + snr = mu - b * math.copysign(1, 0.5 - t) * math.log(1 - 2 * abs(t - 0.5) * 0.98) + return 1 - 1 / (math.exp(snr) + 1.02) + + else: + raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.Tensor`): + the betas that the scheduler is being initialized with. 
+ + Returns: + `torch.Tensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `DPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. 
It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_lu_lambdas (`bool`, *optional*, defaults to `False`): + Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during + the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of + `lambda(t)`. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 
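+
+    Example (a minimal sampling sketch; `denoiser` and the latent shape are illustrative placeholders, not part of
+    this module):
+
+        scheduler = DPMSolverMultistepScheduler(beta_schedule="cosine", algorithm_type="dpmsolver++", solver_order=2)
+        scheduler.set_timesteps(num_inference_steps=20, device="cpu")
+        latents = torch.randn(1, 64)  # hypothetical starting noise
+        for t in scheduler.timesteps:
+            model_output = denoiser(latents, t)  # placeholder denoising model
+            latents = scheduler.step(model_output, t, latents).prev_sample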
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + use_karras_sigmas: Optional[bool] = False, + use_lu_lambdas: Optional[bool] = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + rescale_betas_zero_snr: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message) + + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + elif beta_schedule == "squaredcos_cap_v2" or beta_schedule == "cosine": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cosine") + elif beta_schedule == "cauchy": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="cauchy") + elif beta_schedule == "laplace": + self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace") + else: + raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}") + + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + if rescale_betas_zero_snr: + # Close to 0 without being 0 so first sigma is not inf + # FP16 smallest positive subnormal works well here + self.alphas_cumprod[-1] = 2**-24 + + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + if 
algorithm_type not in ["dpmsolver++", "sde-dpmsolver++"] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + timesteps: Optional[List[int]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated + based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas` + must be `None`, and `timestep_spacing` attribute will be ignored. + """ + if num_inference_steps is None and timesteps is None: + raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.") + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + if timesteps is not None and self.config.use_karras_sigmas: + raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`") + if timesteps is not None and self.config.use_lu_lambdas: + raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`") + + if timesteps is not None: + timesteps = np.array(timesteps).astype(np.int64) + else: + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. 
of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = ( + (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + ) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + + if self.config.use_karras_sigmas: + sigmas = np.flip(sigmas).copy() + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + elif self.config.use_lu_lambdas: + lambdas = np.flip(log_sigmas.copy()) + lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps) + sigmas = np.exp(lambdas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + else: + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = 1 / ((sigma**2 + 1) ** 0.5) + sigma_t = sigma * alpha_t + + return alpha_t, sigma_t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Lu et al. 
(2022).""" + + lambda_min: float = in_lambdas[-1].item() + lambda_max: float = in_lambdas[0].item() + + rho = 1.0 # 1.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = lambda_min ** (1 / rho) + max_inv_rho = lambda_max ** (1 / rho) + lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return lambdas + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError("missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. 
+ if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: 
Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type 
== "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop("timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError(" missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + self.sigmas[self.step_index - 2], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + 
schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: int, + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
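+
+        Example (illustrative single step; assumes `set_timesteps` has already been called, `t` comes from
+        `scheduler.timesteps`, and `model_output`/`sample` are same-shaped tensors):
+
+            out = scheduler.step(model_output, t, sample)
+            sample = out.prev_sample  # feed back in as the input sample for the next timestep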
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32 + ) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to(device=model_output.device, dtype=torch.float32) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype) + # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype) + alpha_t = self.alpha_t.to(original_samples.device).to(original_samples.dtype) + sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + alpha_t = alpha_t[timesteps].flatten() + while len(alpha_t.shape) < len(original_samples.shape): + alpha_t = alpha_t.unsqueeze(-1) + + sigma_t = sigma_t[timesteps].flatten() + while len(sigma_t.shape) < len(original_samples.shape): + sigma_t = sigma_t.unsqueeze(-1) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def get_velocity(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.IntTensor) -> torch.Tensor: + # alpha_t = self.alpha_t.to(device=original_samples.device, dtype=original_samples.dtype) + # sigma_t = self.sigma_t.to(device=original_samples.device, dtype=original_samples.dtype) + alpha_t = 
self.alpha_t.to(original_samples.device).to(original_samples.dtype) + sigma_t = self.sigma_t.to(original_samples.device).to(original_samples.dtype) + + timesteps = timesteps.to(original_samples.device) + alpha_t = alpha_t[timesteps].flatten() + while len(alpha_t.shape) < len(original_samples.shape): + alpha_t = alpha_t.unsqueeze(-1) + + sigma_t = sigma_t[timesteps].flatten() + while len(sigma_t.shape) < len(original_samples.shape): + sigma_t = sigma_t.unsqueeze(-1) + + velocity = alpha_t * noise - sigma_t * original_samples + return velocity + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/vibevoice/schedule/timestep_sampler.py b/vibevoice/schedule/timestep_sampler.py new file mode 100644 index 0000000..177b66f --- /dev/null +++ b/vibevoice/schedule/timestep_sampler.py @@ -0,0 +1,19 @@ +import math +import torch + + +class UniformSampler: + def __init__(self, timesteps = 1000): + self.timesteps = timesteps + def sample(self, batch_size, device): + return torch.randint(0, self.timesteps, (batch_size,), device=device) + +class LogitNormalSampler: + def __init__(self, timesteps = 1000, m = 0, s = 1): + self.timesteps = timesteps + timesteps = torch.linspace(0, 1, timesteps) + logit = torch.log(timesteps / (1 - timesteps)) + self.prob = torch.exp(-0.5 * (logit - m) ** 2 / s ** 2) / (s * math.sqrt(2 * math.pi)) + def sample(self, batch_size, device): + return torch.multinomial(self.prob, batch_size, replacement=True).to(device) + \ No newline at end of file diff --git a/vibevoice/scripts/__init__.py b/vibevoice/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py b/vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py new file mode 100644 index 0000000..bb814cf --- /dev/null +++ b/vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# coding=utf-8 + +import argparse +import json +import os +from pathlib import Path +import re +import torch +from typing import Dict, List, Tuple + +from vibevoice.modular.configuration_vibevoice import ( + VibeVoiceConfig +) +from vibevoice.modular.modeling_vibevoice import VibeVoiceForConditionalGeneration +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +def convert_vibevoice_nnscaler_checkpoint_to_hf( + checkpoint_path: str, + pytorch_dump_folder_path: str, + config_path: str = None, +): + """ + Convert a nnscaler VibeVoice checkpoint to HuggingFace format. + Supports both regular checkpoints and tensor parallel checkpoints. + """ + + # Load regular checkpoint + logger.info(f"Loading regular checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location="cpu") # ['model', 'optimizer', 'lr_scheduler', 'train_status', 'train_args', 'rng_states', 'nnscaler', 'dataloader'] + + # config = checkpoint['train_args'] + init_config_name = checkpoint['train_args']['vars']['model_args']['config_path']['relative_path'] + pretrained_name = checkpoint['train_args']['vars']['data_args']['tokenizer_path'] + + init_config_path = Path(__file__).parent.parent / 'configs' / init_config_name.split('/')[-1] + if init_config_path.exists(): + logger.info(f"Loading initial config from {init_config_path}") + with open(init_config_path, 'r') as f: + init_config = json.load(f) + else: + raise FileNotFoundError(f"Initial config file {init_config_path} not found. 
Please provide a valid path.") + + tie_word_embeddings = init_config['decoder_config'].get('tie_word_embeddings', True) + logger.info(f"Tie word embeddings: {tie_word_embeddings}") + + init_config['decoder_config']['use_cache'] = True + config = VibeVoiceConfig(**init_config, tie_word_embeddings=tie_word_embeddings) + + # # Extract the model state dict + model_state_dict = {k.replace('model.model.', 'model.'): v for k, v in checkpoint["model"].items() if k.startswith('model.model.')} + if not tie_word_embeddings and 'model.lm_head.weight' in checkpoint["model"].keys(): + # If not tying weights, we need to add the lm_head weight separately + model_state_dict['lm_head.weight'] = checkpoint["model"]['model.lm_head.weight'] + + # Override with provided config if available + if config_path: + logger.info(f"Loading config from {config_path}") + with open(config_path, 'r') as f: + config_dict = json.load(f) + config = VibeVoiceConfig.from_dict(config_dict) + + # Set the default dtype to bfloat16 before creating the model + original_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + + # Create the HuggingFace model + logger.info("Creating HuggingFace VibeVoiceForConditionalGeneration model") + model = VibeVoiceForConditionalGeneration(config) + + # Restore original dtype + torch.set_default_dtype(original_dtype) + + # Load the state dict + logger.info("Loading weights into model") + missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False) + + if missing_keys: + logger.warning(f"Missing keys: {missing_keys}") + if unexpected_keys: + logger.warning(f"Unexpected keys: {unexpected_keys}") + + # Create output directory + os.makedirs(pytorch_dump_folder_path, exist_ok=True) + + # Save the model and config + logger.info(f"Saving model to {pytorch_dump_folder_path}") + + # Save config + config.save_pretrained(pytorch_dump_folder_path) + + # Save VibeVoiceProcessor configuration + logger.info("Saving VibeVoiceProcessor configuration") + processor_config = { + "processor_class": "VibeVoiceProcessor", + "speech_tok_compress_ratio": 3200, + "db_normalize": True, + # Audio processor configuration + "audio_processor": { + "feature_extractor_type": "VibeVoiceTokenizerProcessor", + "sampling_rate": 24000, + "normalize_audio": True, + "target_dB_FS": -25, + "eps": 1e-6, + }, + "language_model_pretrained_name": pretrained_name, + } + + processor_config_path = os.path.join(pytorch_dump_folder_path, "preprocessor_config.json") + with open(processor_config_path, 'w') as f: + json.dump(processor_config, f, indent=2) + logger.info(f"Saved processor config to {processor_config_path}") + + # Save model with sharding + # save_pretrained handles tied weights automatically + logger.info("Saving model weights with sharding...") + model.save_pretrained( + pytorch_dump_folder_path, + max_shard_size="2GB", # Set maximum size for each shard + safe_serialization=True # Ensure saving in .safetensors format + ) + logger.info(f"Model weights saved to {pytorch_dump_folder_path}") + + logger.info("Conversion complete!") + + # Verify the saved model can be loaded + logger.info("Verifying saved model...") + loaded_model = VibeVoiceForConditionalGeneration.from_pretrained(pytorch_dump_folder_path) + logger.info("Model successfully loaded from saved checkpoint!") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--nnscaler_checkpoint_path", + type=str, + required=True, + help="Path to the fairseq checkpoint (.pt file). 
For tensor parallel checkpoints, " + "provide any one of the part files (e.g., checkpoint_1_5000-model_part-0.pt), " + "and the script will automatically detect and merge all parts.", + ) + parser.add_argument( + "--pytorch_dump_folder_path", + type=str, + required=True, + help="Path to the output PyTorch model directory", + ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help="Optional path to a config JSON file to override extracted config", + ) + + args = parser.parse_args() + + convert_vibevoice_nnscaler_checkpoint_to_hf( + args.nnscaler_checkpoint_path, + args.pytorch_dump_folder_path, + args.config_path, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file
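+
+# Example invocation (illustrative; the checkpoint and output paths below are placeholders):
+#   python vibevoice/scripts/convert_nnscaler_checkpoint_to_transformers.py \
+#       --nnscaler_checkpoint_path /path/to/nnscaler/checkpoint.pt \
+#       --pytorch_dump_folder_path /path/to/output_hf_model \
+#       --config_path /path/to/override_config.json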