Restore local SeamlessM4T extractor fallback - Simplified / better readme

This commit is contained in:
snicolast
2025-10-05 12:27:05 +13:00
parent 586bd77efe
commit b2418eeff0
2 changed files with 61 additions and 82 deletions

View File

@@ -134,7 +134,26 @@ class IndexTTS2:
print(f"{e!r}")
self.use_cuda_kernel = False
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
local_w2v_dir_env = os.environ.get("W2V_BERT_LOCAL_DIR")
local_w2v_dir = local_w2v_dir_env or os.path.join(self.model_dir, "w2v-bert-2.0")
if local_w2v_dir and os.path.isdir(local_w2v_dir):
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained(local_w2v_dir, local_files_only=True)
print(f">> SeamlessM4TFeatureExtractor loaded from: {local_w2v_dir}")
else:
fe_kwargs = {}
if HF_AUTH_TOKEN:
fe_kwargs["token"] = HF_AUTH_TOKEN
try:
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", **fe_kwargs)
print(">> SeamlessM4TFeatureExtractor loaded from HF repo: facebook/w2v-bert-2.0")
except HfHubHTTPError as err:
status = getattr(err.response, "status_code", None)
if status == 401 and HF_AUTH_TOKEN:
print(f"[IndexTTS2] Feature extractor download failed with 401; retrying anonymously.")
self.extract_features = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", token=False)
print(">> SeamlessM4TFeatureExtractor loaded from HF repo: facebook/w2v-bert-2.0")
else:
raise
self.semantic_model, self.semantic_mean, self.semantic_std = build_semantic_model(
os.path.join(self.model_dir, self.cfg.w2v_stat))
self.semantic_model = self.semantic_model.to(self.device)