refactor THA3 model installer

This commit is contained in:
Juha Jeronen
2024-01-08 15:45:44 +02:00
parent 32cda5786c
commit bd29500c9e
3 changed files with 39 additions and 39 deletions

View File

@@ -183,25 +183,8 @@ if not torch.cuda.is_available() and not args.cpu:
print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
if "talkinghead" in modules:
# Install the THA3 models if needed
talkinghead_models_dir = os.path.join(os.getcwd(), "talkinghead", "tha3", "models")
if not os.path.exists(talkinghead_models_dir):
# API:
# https://huggingface.co/docs/huggingface_hub/en/guides/download
try:
from huggingface_hub import snapshot_download
except ImportError:
raise ImportError(
"You need to install huggingface_hub to install talkinghead models automatically. "
"See https://pypi.org/project/huggingface-hub/ for installation."
)
os.makedirs(talkinghead_models_dir, exist_ok=True)
print(f"THA3 models not yet installed. Installing from {args.talkinghead_models} into talkinghead/tha3/models.")
# Installing with symlinks would be generally better, but MS Windows support for symlinks is not optimal,
# so for maximal compatibility we avoid them. The drawback of installing directly as plain files is that
# if multiple programs need to download THA3, they will do so separately. But THA3 is rather rare, so in
# practice this is unlikely to be an issue.
snapshot_download(repo_id=args.talkinghead_models, local_dir=talkinghead_models_dir, local_dir_use_symlinks=False)
talkinghead_path = os.path.abspath(os.path.join(os.getcwd(), "talkinghead"))
sys.path.append(talkinghead_path) # Add the path to the 'tha3' module to the sys.path list
import sys
import threading
@@ -211,10 +194,14 @@ if "talkinghead" in modules:
# FP16 boosts the rendering performance by ~1.5x, but is only supported on GPU.
model = "separable_half" if args.talkinghead_gpu else "separable_float"
print(f"Initializing talkinghead pipeline in {mode} mode with model {model}....")
talkinghead_path = os.path.abspath(os.path.join(os.getcwd(), "talkinghead"))
sys.path.append(talkinghead_path) # Add the path to the 'tha3' module to the sys.path list
try:
from talkinghead.tha3.app.util import maybe_install_models as talkinghead_maybe_install_models
# Install the THA3 models if needed
talkinghead_models_dir = os.path.join(os.getcwd(), "talkinghead", "tha3", "models")
talkinghead_maybe_install_models(hf_reponame=args.talkinghead_models, modelsdir=talkinghead_models_dir)
import talkinghead.tha3.app.app as talkinghead
def launch_talkinghead():
# mode: choices='The device to use for PyTorch ("cuda" for GPU, "cpu" for CPU).'