mirror of
https://github.com/snicolast/ComfyUI-IndexTTS2.git
synced 2026-04-30 03:31:34 +00:00
w2v bert model for 100% offline + fixed/unified HF fallback
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
import librosa
|
||||
import json5
|
||||
@@ -85,7 +86,16 @@ class JsonHParams:
|
||||
|
||||
|
||||
def build_semantic_model(path_='./models/tts/maskgct/ckpt/wav2vec2bert_stats.pt'):
|
||||
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
|
||||
local_dir_env = os.environ.get('W2V_BERT_LOCAL_DIR')
|
||||
stats_dir = os.path.dirname(path_) if path_ else None
|
||||
local_dir_default = os.path.join(stats_dir, 'w2v-bert-2.0') if stats_dir else None
|
||||
local_dir = local_dir_env or local_dir_default
|
||||
if local_dir and os.path.isdir(local_dir):
|
||||
semantic_model = Wav2Vec2BertModel.from_pretrained(local_dir, local_files_only=True)
|
||||
print('>> Wav2Vec2BertModel loaded from:', local_dir)
|
||||
else:
|
||||
semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
|
||||
print('>> Wav2Vec2BertModel loaded from HF repo: facebook/w2v-bert-2.0')
|
||||
semantic_model.eval()
|
||||
stat_mean_var = torch.load(path_)
|
||||
semantic_mean = stat_mean_var["mean"]
|
||||
|
||||
Reference in New Issue
Block a user