w2v bert model for 100% offline + fixed/unified HF fallback

This commit is contained in:
snicolast
2025-09-14 11:34:26 +12:00
parent 36d3abd27c
commit 016e29d103
3 changed files with 92 additions and 10 deletions

View File

@@ -1,3 +1,4 @@
import os
import torch
import librosa
import json5
@@ -85,7 +86,16 @@ class JsonHParams:
def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
local_dir_env = os.environ.get('W2V_BERT_LOCAL_DIR')
stats_dir = os.path.dirname(path_) if path_ else None
local_dir_default = os.path.join(stats_dir, 'w2v-bert-2.0') if stats_dir else None
local_dir = local_dir_env or local_dir_default
if local_dir and os.path.isdir(local_dir):
semantic_model = Wav2Vec2BertModel.from_pretrained(local_dir, local_files_only=True)
print('>> Wav2Vec2BertModel loaded from:', local_dir)
else:
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
print('>> Wav2Vec2BertModel loaded from HF repo: facebook/w2v-bert-2.0')
semantic_model.eval()
stat_mean_var = torch.load(path_)
semantic_mean = stat_mean_var["mean"]