w2v bert model for 100% offline + fixed/unified HF fallback

This commit is contained in:
snicolast
2025-09-14 11:34:26 +12:00
parent 36d3abd27c
commit 016e29d103
3 changed files with 92 additions and 10 deletions

View File

@@ -113,7 +113,26 @@ class IndexTTS2:
print(">> Failed to load custom CUDA kernel for BigVGAN. Falling back to torch.")
self.use_cuda_kernel = False
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
#Prefer local w2v-bert-2.0 if present; otherwise allow HF download
try:
local_w2v_dir = os.path.join(self.model_dir, "w2v-bert-2.0")
if os.path.isdir(local_w2v_dir):
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
local_w2v_dir,
local_files_only=True,
)
print(">> W2V-BERT feature extractor loaded from:", local_w2v_dir)
else:
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
"facebook/w2v-bert-2.0"
)
print(">> W2V-BERT feature extractor loaded from HF repo: facebook/w2v-bert-2.0")
except Exception as e:
# Fallback to HF repo if local load fails unexpectedly
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(
"facebook/w2v-bert-2.0"
)
print(">> W2V-BERT feature extractor: local load failed; using HF repo. Error:", e)
self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
os.path.join(self.model_dir, self.cfg.w2v_stat))
self.semantic_model = self.semantic_model.to(self.device)
@@ -123,10 +142,27 @@ class IndexTTS2:
semantic_codec = build_semantic_codec(self.cfg.semantic_codec)
semantic_code_ckpt = os.path.join(self.model_dir, "semantic_codec/model.safetensors")
safetensors.torch.load_model(semantic_codec, semantic_code_ckpt)
ckpt_to_load = semantic_code_ckpt
if not os.path.isfile(semantic_code_ckpt):
try:
# Attempt to download into HF cache and use from there
hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
os.makedirs(hf_cache_dir, exist_ok=True)
ckpt_to_load = hf_hub_download(
repo_id="amphion/MaskGCT",
filename="semantic_codec/model.safetensors",
cache_dir=hf_cache_dir,
local_files_only=False,
)
print(">> semantic_codec weights downloaded to cache:", ckpt_to_load)
except Exception as e:
raise FileNotFoundError(
f"semantic_codec/model.safetensors not found and download failed: {e}"
)
safetensors.torch.load_model(semantic_codec, ckpt_to_load)
self.semantic_codec = semantic_codec.to(self.device)
self.semantic_codec.eval()
print('>> semantic_codec weights restored from: {}'.format(semantic_code_ckpt))
print('>> semantic_codec weights restored from: {}'.format(ckpt_to_load))
s2mel_path = os.path.join(self.model_dir, self.cfg.s2mel_checkpoint)
s2mel = MyModel(self.cfg.s2mel, use_gpt_latent=True)
@@ -143,20 +179,48 @@ class IndexTTS2:
self.s2mel.eval()
print(">> s2mel weights restored from:", s2mel_path)
# load campplus_model
# load campplus_model (local first; fallback to HF cache)
campplus_ckpt_path = os.path.join(self.model_dir, "campplus_cn_common.bin")
campplus_ckpt_to_load = campplus_ckpt_path
if not os.path.isfile(campplus_ckpt_path):
try:
hf_cache_dir = os.path.join(self.model_dir, 'hf_cache')
os.makedirs(hf_cache_dir, exist_ok=True)
campplus_ckpt_to_load = hf_hub_download(
repo_id="funasr/campplus",
filename="campplus_cn_common.bin",
cache_dir=hf_cache_dir,
local_files_only=False,
)
print(">> campplus_model weights downloaded to cache:", campplus_ckpt_to_load)
except Exception as e:
raise FileNotFoundError(
f"campplus_cn_common.bin not found and download failed: {e}"
)
campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
campplus_model.load_state_dict(torch.load(campplus_ckpt_to_load, map_location="cpu"))
self.campplus_model = campplus_model.to(self.device)
self.campplus_model.eval()
print(">> campplus_model weights restored from:", campplus_ckpt_path)
print(">> campplus_model weights restored from:", campplus_ckpt_to_load)
bigvgan_name = os.path.join(self.model_dir, self.cfg.vocoder.name)
self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_name, use_cuda_kernel=self.use_cuda_kernel)
# BigVGAN: prefer local dir if complete; otherwise fallback to HF repo
bigvgan_local_dir = os.path.join(self.model_dir, self.cfg.vocoder.name)
bigvgan_config = os.path.join(bigvgan_local_dir, "config.json")
bigvgan_weights = os.path.join(bigvgan_local_dir, "bigvgan_generator.pt")
bigvgan_source = None
if os.path.isdir(bigvgan_local_dir) and os.path.isfile(bigvgan_config) and os.path.isfile(bigvgan_weights):
self.bigvgan = bigvgan.BigVGAN.from_pretrained(bigvgan_local_dir, use_cuda_kernel=self.use_cuda_kernel)
bigvgan_source = bigvgan_local_dir
else:
# fallback to HF repo id (cached under HF_HUB_CACHE)
repo_id = "nvidia/bigvgan_v2_22khz_80band_256x"
print(">> BigVGAN local files missing or incomplete; loading from HF repo:", repo_id)
self.bigvgan = bigvgan.BigVGAN.from_pretrained(repo_id, use_cuda_kernel=self.use_cuda_kernel)
bigvgan_source = repo_id
self.bigvgan = self.bigvgan.to(self.device)
self.bigvgan.remove_weight_norm()
self.bigvgan.eval()
print(">> bigvgan weights restored from:", bigvgan_name)
print(">> bigvgan weights restored from:", bigvgan_source)
self.bpe_path = os.path.join(self.model_dir, self.cfg.dataset["bpe_model"])
self.normalizer = TextNormalizer()

View File

@@ -1,3 +1,4 @@
import os
import torch
import librosa
import json5
@@ -85,7 +86,16 @@ class JsonHParams:
def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
local_dir_env = os.environ.get('W2V_BERT_LOCAL_DIR')
stats_dir = os.path.dirname(path_) if path_ else None
local_dir_default = os.path.join(stats_dir, 'w2v-bert-2.0') if stats_dir else None
local_dir = local_dir_env or local_dir_default
if local_dir and os.path.isdir(local_dir):
semantic_model = Wav2Vec2BertModel.from_pretrained(local_dir, local_files_only=True)
print('>> Wav2Vec2BertModel loaded from:', local_dir)
else:
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
print('>> Wav2Vec2BertModel loaded from HF repo: facebook/w2v-bert-2.0')
semantic_model.eval()
stat_mean_var = torch.load(path_)
semantic_mean = stat_mean_var["mean"]