Mirror of https://github.com/snicolast/ComfyUI-IndexTTS2.git, synced 2026-01-26 14:39:44 +00:00

Commit: w2v bert model for 100% offline + fixed/unified HF fallback
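The diff below makes every checkpoint loader prefer files inside the node's model directory and fall back to the Hugging Face Hub only when they are missing. For fully offline use, the remote fallbacks can be pre-fetched once; the sketch below assumes a `models/IndexTTS-2` directory and folder names matching the paths used in the diff, neither of which is defined by this commit.

```python
# Sketch: pre-fetch every remote fallback used in this commit into a local
# model directory, so later runs never need network access.
# "models/IndexTTS-2" and the folder names are assumptions, not repo paths.
import os
from huggingface_hub import hf_hub_download, snapshot_download

model_dir = "models/IndexTTS-2"

# w2v-bert-2.0 (feature extractor config + Wav2Vec2BertModel weights)
snapshot_download("facebook/w2v-bert-2.0",
                  local_dir=os.path.join(model_dir, "w2v-bert-2.0"))

# MaskGCT semantic codec -> model_dir/semantic_codec/model.safetensors
hf_hub_download("amphion/MaskGCT", "semantic_codec/model.safetensors",
                local_dir=model_dir)

# CAM++ speaker model -> model_dir/campplus_cn_common.bin
hf_hub_download("funasr/campplus", "campplus_cn_common.bin",
                local_dir=model_dir)

# BigVGAN vocoder; the folder name must match cfg.vocoder.name
snapshot_download("nvidia/bigvgan_v2_22khz_80band_256x",
                  local_dir=os.path.join(model_dir, "bigvgan_v2_22khz_80band_256x"))
```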
```diff
@@ -113,7 +113,26 @@ class IndexTTS2:
                 print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
                 self.use_cuda_kernel = False
 
-        self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
+        # Prefer local w2v-bert-2.0 if present; otherwise allow HF download
+        try:
+            local_w2v_dir = os.path.join(self.model_dir, "w2v-bert-2.0")
+            if os.path.isdir(local_w2v_dir):
+                self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
+                    local_w2v_dir,
+                    local_files_only=True,
+                )
+                print(">> W2V-BERT feature extractor loaded from:", local_w2v_dir)
+            else:
+                self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
+                    "facebook/w2v-bert-2.0"
+                )
+                print(">> W2V-BERT feature extractor loaded from HF repo: facebook/w2v-bert-2.0")
+        except Exception as e:
+            # Fallback to HF repo if local load fails unexpectedly
+            self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
+                "facebook/w2v-bert-2.0"
+            )
+            print(">> W2V-BERT feature extractor: local load failed; using HF repo. Error:", e)
         self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
             os.path.join(self.model_dir, self.cfg.w2v_stat))
         self.semantic_model = self.semantic_model.to(self.device)
```
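SeamlessM4TFeatureExtractor only needs its preprocessor configuration on disk for the `local_files_only` branch above to work, so the local folder can be created with one online call to `save_pretrained`. A minimal sketch, assuming the same `models/IndexTTS-2` layout as above:

```python
# Sketch: build the local w2v-bert-2.0 folder the loader above prefers.
# One online save_pretrained() run writes preprocessor_config.json; after
# that, from_pretrained(local_dir, local_files_only=True) stays offline.
# "models/IndexTTS-2" is an assumed model directory.
import os
from transformers import SeamlessM4TFeatureExtractor

local_w2v_dir = os.path.join("models/IndexTTS-2", "w2v-bert-2.0")

if not os.path.isdir(local_w2v_dir):
    SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0").save_pretrained(local_w2v_dir)

extract_features = SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_dir, local_files_only=True)
```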
```diff
@@ -123,10 +142,27 @@ class IndexTTS2:
 
         semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
         semantic_code_ckpt = os.path.join(self.model_dir, "semantic_codec/model.safetensors")
-        safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
+        ckpt_to_load = semantic_code_ckpt
+        if not os.path.isfile(semantic_code_ckpt):
+            try:
+                # Attempt to download into HF cache and use from there
+                hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
+                os.makedirs(hf_cache_dir, exist_ok=True)
+                ckpt_to_load = hf_hub_download(
+                    repo_id="amphion/MaskGCT",
+                    filename="semantic_codec/model.safetensors",
+                    cache_dir=hf_cache_dir,
+                    local_files_only=False,
+                )
+                print(">> semantic_codec weights downloaded to cache:", ckpt_to_load)
+            except Exception as e:
+                raise FileNotFoundError(
+                    f"semantic_codec/model.safetensors not found and download failed: {e}"
+                )
+        safetensors.torch.load_model(semantic_codec, ckpt_to_load)
         self.semantic_codec = semantic_codec.to(self.device)
         self.semantic_codec.eval()
-        print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))
+        print('>> semantic_codec weights restored from: {}'.format(ckpt_to_load))
 
         s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
         s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
```
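`hf_hub_download` returns the resolved path of the file inside the cache directory, so once `model_dir/hf_cache` has been populated by a single online run, the same call can be made with `local_files_only=True` (or with `HF_HUB_OFFLINE=1` exported) and never touch the network. A sketch, with `model_dir` again an assumed path:

```python
# Sketch: after one online run has filled model_dir/hf_cache, the identical
# hf_hub_download call resolves from that cache without any network access.
# model_dir is an assumed path.
import os
from huggingface_hub import hf_hub_download

model_dir = "models/IndexTTS-2"
hf_cache_dir = os.path.join(model_dir, "hf_cache")

ckpt_to_load = hf_hub_download(
    repo_id="amphion/MaskGCT",
    filename="semantic_codec/model.safetensors",
    cache_dir=hf_cache_dir,
    local_files_only=True,  # raises LocalEntryNotFoundError if never downloaded
)
print("semantic_codec resolved from cache:", ckpt_to_load)
```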
```diff
@@ -143,20 +179,48 @@ class IndexTTS2:
         self.s2mel.eval()
         print(">> s2mel weights restored from:", s2mel_path)
 
-        # load campplus_model
+        # load campplus_model (local first; fallback to HF cache)
         campplus_ckpt_path = os.path.join(self.model_dir, "campplus_cn_common.bin")
+        campplus_ckpt_to_load = campplus_ckpt_path
+        if not os.path.isfile(campplus_ckpt_path):
+            try:
+                hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
+                os.makedirs(hf_cache_dir, exist_ok=True)
+                campplus_ckpt_to_load = hf_hub_download(
+                    repo_id="funasr/campplus",
+                    filename="campplus_cn_common.bin",
+                    cache_dir=hf_cache_dir,
+                    local_files_only=False,
+                )
+                print(">> campplus_model weights downloaded to cache:", campplus_ckpt_to_load)
+            except Exception as e:
+                raise FileNotFoundError(
+                    f"campplus_cn_common.bin not found and download failed: {e}"
+                )
         campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
-        campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
+        campplus_model.load_state_dict(torch.load(campplus_ckpt_to_load, map_location="cpu"))
         self.campplus_model = campplus_model.to(self.device)
         self.campplus_model.eval()
-        print(">> campplus_model weights restored from:", campplus_ckpt_path)
+        print(">> campplus_model weights restored from:", campplus_ckpt_to_load)
 
-        bigvgan_name = os.path.join(self.model_dir, self.cfg.vocoder.name)
-        self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=self.use_cuda_kernel)
+        # BigVGAN: prefer local dir if complete; otherwise fallback to HF repo
+        bigvgan_local_dir = os.path.join(self.model_dir, self.cfg.vocoder.name)
+        bigvgan_config = os.path.join(bigvgan_local_dir, "config.json")
+        bigvgan_weights = os.path.join(bigvgan_local_dir, "bigvgan_generator.pt")
+        bigvgan_source = None
+        if os.path.isdir(bigvgan_local_dir) and os.path.isfile(bigvgan_config) and os.path.isfile(bigvgan_weights):
+            self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_local_dir, use_cuda_kernel=self.use_cuda_kernel)
+            bigvgan_source = bigvgan_local_dir
+        else:
+            # fallback to HF repo id (cached under HF_HUB_CACHE)
+            repo_id = "nvidia/bigvgan_v2_22khz_80band_256x"
+            print(">> BigVGAN local files missing or incomplete; loading from HF repo:", repo_id)
+            self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel)
+            bigvgan_source = repo_id
         self.bigvgan = self.bigvgan.to(self.device)
         self.bigvgan.remove_weight_norm()
         self.bigvgan.eval()
-        print(">> bigvgan weights restored from:", bigvgan_name)
+        print(">> bigvgan weights restored from:", bigvgan_source)
 
         self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
         self.normalizer = TextNormalizer()
```
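The BigVGAN branch above only checks for `config.json` and `bigvgan_generator.pt`, so staging those two files (plus the CAM++ checkpoint, which follows the same local-first rule) is enough to keep the HF-repo fallback from ever firing. A sketch; the vocoder folder name must match `self.cfg.vocoder.name` and both paths are assumptions:

```python
# Sketch: stage the files the local-first checks above look for.
# The vocoder folder name must equal self.cfg.vocoder.name; both paths
# here are assumptions for illustration.
import os
from huggingface_hub import hf_hub_download

model_dir = "models/IndexTTS-2"
bigvgan_dir = os.path.join(model_dir, "bigvgan_v2_22khz_80band_256x")

# BigVGAN: exactly the two files checked before the HF-repo fallback
for filename in ("config.json", "bigvgan_generator.pt"):
    hf_hub_download("nvidia/bigvgan_v2_22khz_80band_256x", filename,
                    local_dir=bigvgan_dir)

# CAM++: same local-first rule as semantic_codec
hf_hub_download("funasr/campplus", "campplus_cn_common.bin",
                local_dir=model_dir)
```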
```diff
@@ -1,3 +1,4 @@
+import os
 import torch
 import librosa
 import json5
```
```diff
@@ -85,7 +86,16 @@ class JsonHParams:
 
 
 def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
-    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+    local_dir_env = os.environ.get('W2V_BERT_LOCAL_DIR')
+    stats_dir = os.path.dirname(path_) if path_ else None
+    local_dir_default = os.path.join(stats_dir, 'w2v-bert-2.0') if stats_dir else None
+    local_dir = local_dir_env or local_dir_default
+    if local_dir and os.path.isdir(local_dir):
+        semantic_model = Wav2Vec2BertModel.from_pretrained(local_dir, local_files_only=True)
+        print('>> Wav2Vec2BertModel loaded from:', local_dir)
+    else:
+        semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
+        print('>> Wav2Vec2BertModel loaded from HF repo: facebook/w2v-bert-2.0')
     semantic_model.eval()
     stat_mean_var = torch.load(path_)
     semantic_mean = stat_mean_var["mean"]
```
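`build_semantic_model` now honours the `W2V_BERT_LOCAL_DIR` environment variable before falling back to the default `<stats dir>/w2v-bert-2.0` location and finally to the Hub. A sketch of preparing such a local copy and pointing the variable at it; the destination path is an assumption:

```python
# Sketch: save the Wav2Vec2-BERT weights locally once and point the
# W2V_BERT_LOCAL_DIR variable (read by build_semantic_model above) at them.
# The destination path is an assumption for illustration.
import os
from transformers import Wav2Vec2BertModel

local_dir = "models/IndexTTS-2/w2v-bert-2.0"

if not os.path.isdir(local_dir):
    Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0").save_pretrained(local_dir)

os.environ["W2V_BERT_LOCAL_DIR"] = local_dir  # checked before the default <stats dir>/w2v-bert-2.0
```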