talkinghead: fix and improve THA3 manual poser (#204)

* talkinghead: fix and improve THA3 manual poser

* server.py: no, don't yet use fp16 for talkinghead
This commit is contained in:
Juha Jeronen
2023-12-20 01:21:24 +02:00
committed by GitHub
parent c5d1773f6d
commit 7ca92eaeac
10 changed files with 1159 additions and 560 deletions

View File

@@ -0,0 +1,14 @@
#!/bin/bash
#
# Launch the THA3 manual poser app.
#
# This app can be used to generate static expression images, given just
# one static input image in the appropriate format.
#
# This app is standalone, and does not interact with SillyTavern.
#
# This must run in the "extras" conda venv!
# Do this first:
#   conda activate extras
#
# Quote "$@" so arguments containing spaces are forwarded intact,
# and exec so the poser replaces this wrapper shell process.
exec python -m tha3.app.manual_poser "$@"

View File

@@ -0,0 +1,20 @@
#!/bin/bash
#
# Launch THA3 in standalone app mode.
#
# This standalone app mode does not interact with SillyTavern.
#
# The usual way to run this fork of THA3 is as a SillyTavern-extras plugin.
# The standalone app mode comes from the original THA3 code, and is included
# for testing and debugging.
#
# If you want to manually pose a character (to generate static expression images),
# use `start_manual_poser.sh` instead.
#
# This must run in the "extras" conda venv!
# Do this first:
#   conda activate extras
#
# The `--char=...` flag can be used to specify which image to load under "tha3/images".
#
# "$@" is quoted so arguments containing spaces survive forwarding, and is
# appended AFTER the default --char so a user-supplied --char takes precedence
# (argparse uses the last occurrence). exec replaces this wrapper shell.
exec python -m tha3.app.app --char=example.png "$@"

View File

@@ -1,3 +1,5 @@
# TODO: Standalone app mode does not work yet. The SillyTavern-extras plugin mode works.
import argparse
import ast
import os
@@ -7,19 +9,18 @@ import threading
import time
import torch
import io
import torch.nn.functional as F
import wx
import numpy as np
import json
import typing
from PIL import Image
from torchvision import transforms
from flask import Flask, Response
from flask_cors import CORS
from io import BytesIO
sys.path.append(os.getcwd())
from tha3.mocap.ifacialmocap_constants import *
from tha3.mocap import ifacialmocap_constants as mocap_constants
from tha3.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose
from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter
from tha3.mocap.ifacialmocap_poser_converter_25 import create_ifacialmocap_pose_converter
@@ -31,7 +32,7 @@ from tha3.util import (
)
from typing import Optional
# Global Variables
# Global variables
global_source_image = None
global_result_image = None
global_reload = None
@@ -43,7 +44,7 @@ lasttranisitiondPose = "NotInit"
inMotion = False
fps = 0
current_pose = None
storepath = os.path.join(os.getcwd(), "talkinghead", "emotions")
global_basedir = "talkinghead" # for SillyTavern-extras live mode; if running standalone, we override this later
# Flask setup
app = Flask(__name__)
@@ -60,7 +61,7 @@ def setEmotion(_emotion):
highest_score = item['score']
highest_label = item['label']
#print("Applying ", emotion)
# print("Applying ", emotion)
emotion = highest_label
def unload():
@@ -85,10 +86,10 @@ def result_feed():
try:
rgb_image = global_result_image[:, :, [2, 1, 0]] # Swap B and R channels
pil_image = Image.fromarray(np.uint8(rgb_image)) # Convert to PIL Image
if global_result_image.shape[2] == 4: # Check if there is an alpha channel present
alpha_channel = global_result_image[:, :, 3] # Extract alpha channel
pil_image.putalpha(Image.fromarray(np.uint8(alpha_channel))) # Set alpha channel in the PIL Image
buffer = io.BytesIO() # Save as PNG with RGBA mode
if global_result_image.shape[2] == 4: # Check if there is an alpha channel present
alpha_channel = global_result_image[:, :, 3] # Extract alpha channel
pil_image.putalpha(Image.fromarray(np.uint8(alpha_channel))) # Set alpha channel in the PIL Image
buffer = io.BytesIO() # Save as PNG with RGBA mode
pil_image.save(buffer, format='PNG')
image_bytes = buffer.getvalue()
except Exception as e:
@@ -100,19 +101,20 @@ def result_feed():
return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame')
def talkinghead_load_file(stream):
global global_basedir
global global_source_image
global global_reload
global global_timer_paused
global_timer_paused = False
try:
pil_image = Image.open(stream) # Load the image using PIL.Image.open
img_data = BytesIO() # Create a copy of the image data in memory using BytesIO
pil_image = Image.open(stream) # Load the image using PIL.Image.open
img_data = BytesIO() # Create a copy of the image data in memory using BytesIO
pil_image.save(img_data, format='PNG')
global_reload = Image.open(BytesIO(img_data.getvalue())) # Set the global_reload to the copy of the image data
global_reload = Image.open(BytesIO(img_data.getvalue())) # Set the global_reload to the copy of the image data
except Image.UnidentifiedImageError:
print(f"Could not load image from file, loading blank")
full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.sep.join(["talkinghead", "tha3", "images", "inital.png"])))
print("Could not load image from file, loading blank")
full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.join(global_basedir, "tha3", "images", "inital.png")))
MainFrame.load_image(None, full_path)
global_timer_paused = True
return 'OK'
@@ -121,32 +123,41 @@ def convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor:
rgb_image = torch_linear_to_srgb(image[0:3, :, :])
return torch.cat([rgb_image, image[3:4, :, :]], dim=0)
def launch_gui(device, model):
def launch_gui(device: str, model: str, char: typing.Optional[str] = None, standalone: bool = False):
"""
device: "cpu" or "cuda"
model: one of the folder names inside "talkinghead/tha3/models/"
char: name of png file inside "talkinghead/tha3/images/"; if not given, defaults to "inital.png".
"""
global global_basedir
global initAMI
initAMI = True
# TODO: We could use this to parse the arguments that were provided to `server.py`, but we don't currently use the parser output.
parser = argparse.ArgumentParser(description='uWu Waifu')
# Add other parser arguments here
args, unknown = parser.parse_known_args()
try:
poser = load_poser(model, device)
pose_converter = create_ifacialmocap_pose_converter() #creates a list of 45
if char is None:
char = "inital.png"
app = wx.App(redirect=False)
try:
poser = load_poser(model, device, modelsdir=os.path.join(global_basedir, "tha3", "models"))
pose_converter = create_ifacialmocap_pose_converter() # creates a list of 45
app = wx.App()
main_frame = MainFrame(poser, pose_converter, device)
main_frame.SetSize((750, 600))
#Lload default image (you can pass args.char if required)
full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.sep.join(["talkinghead", "tha3", "images", "inital.png"])))
# Load character image
full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.join(global_basedir, "tha3", "images", char)))
main_frame.load_image(None, full_path)
#main_frame.Show(True)
if standalone:
main_frame.Show(True)
main_frame.capture_timer.Start(100)
main_frame.animation_timer.Start(100)
wx.DisableAsserts() #prevent popup about debug alert closed from other threads
wx.DisableAsserts() # prevent popup about debug alert closed from other threads
app.MainLoop()
except RuntimeError as e:
@@ -235,7 +246,7 @@ class MainFrame(wx.Frame):
current_pose = self.ifacialmocap_pose
# NOTE: randomize mouth
for blendshape_name in BLENDSHAPE_NAMES:
for blendshape_name in mocap_constants.BLENDSHAPE_NAMES:
if "jawOpen" in blendshape_name:
if is_talking or is_talking_override:
current_pose[blendshape_name] = self.random_generate_value(-5000, 5000, abs(1 - current_pose[blendshape_name]))
@@ -247,7 +258,7 @@ class MainFrame(wx.Frame):
def animationHeadMove(self):
current_pose = self.ifacialmocap_pose
for key in [HEAD_BONE_Y]: #can add more to this list if needed
for key in [mocap_constants.HEAD_BONE_Y]: # can add more to this list if needed
current_pose[key] = self.random_generate_value(-20, 20, current_pose[key])
return current_pose
@@ -265,29 +276,30 @@ class MainFrame(wx.Frame):
return current_pose
def addNamestoConvert(pose):
# TODO: What are the unknown keys?
index_to_name = {
0: 'eyebrow_troubled_left_index', #COMBACK TO UNK
1: 'eyebrow_troubled_right_index',#COMBACK TO UNK
0: 'eyebrow_troubled_left_index',
1: 'eyebrow_troubled_right_index',
2: 'eyebrow_angry_left_index',
3: 'eyebrow_angry_right_index',
4: 'unknown1', #COMBACK TO UNK
5: 'unknown2', #COMBACK TO UNK
4: 'unknown1', # COMBACK TO UNK
5: 'unknown2', # COMBACK TO UNK
6: 'eyebrow_raised_left_index',
7: 'eyebrow_raised_right_index',
8: 'eyebrow_happy_left_index',
9: 'eyebrow_happy_right_index',
10: 'unknown3', #COMBACK TO UNK
11: 'unknown4', #COMBACK TO UNK
10: 'unknown3', # COMBACK TO UNK
11: 'unknown4', # COMBACK TO UNK
12: 'wink_left_index',
13: 'wink_right_index',
14: 'eye_happy_wink_left_index',
15: 'eye_happy_wink_right_index',
16: 'eye_surprised_left_index',
17: 'eye_surprised_right_index',
18: 'unknown5', #COMBACK TO UNK
19: 'unknown6', #COMBACK TO UNK
20: 'unknown7', #COMBACK TO UNK
21: 'unknown8', #COMBACK TO UNK
18: 'unknown5', # COMBACK TO UNK
19: 'unknown6', # COMBACK TO UNK
20: 'unknown7', # COMBACK TO UNK
21: 'unknown8', # COMBACK TO UNK
22: 'eye_raised_lower_eyelid_left_index',
23: 'eye_raised_lower_eyelid_right_index',
24: 'iris_small_left_index',
@@ -295,14 +307,14 @@ class MainFrame(wx.Frame):
26: 'mouth_aaa_index',
27: 'mouth_iii_index',
28: 'mouth_ooo_index',
29: 'unknown9a', #COMBACK TO UNK
29: 'unknown9a', # COMBACK TO UNK
30: 'mouth_ooo_index2',
31: 'unknown9', #COMBACK TO UNK
32: 'unknown10', #COMBACK TO UNK
33: 'unknown11', #COMBACK TO UNK
31: 'unknown9', # COMBACK TO UNK
32: 'unknown10', # COMBACK TO UNK
33: 'unknown11', # COMBACK TO UNK
34: 'mouth_raised_corner_left_index',
35: 'mouth_raised_corner_right_index',
36: 'unknown12',
36: 'unknown12', # COMBACK TO UNK
37: 'iris_rotation_x_index',
38: 'iris_rotation_y_index',
39: 'head_x_index',
@@ -321,17 +333,16 @@ class MainFrame(wx.Frame):
return output
def get_emotion_values(self, emotion): # Place to define emotion presets
global storepath
def get_emotion_values(self, emotion): # Place to define emotion presets
global global_basedir
#print(emotion)
file_path = os.path.join(storepath, emotion + ".json")
#print("trying: ", file_path)
# print(emotion)
file_path = os.path.join(global_basedir, "emotions", emotion + ".json")
# print("trying: ", file_path)
if not os.path.exists(file_path):
print("using backup for: ", file_path)
file_path = os.path.join(storepath, "_defaults.json")
print("using backup for: ", file_path)
file_path = os.path.join(global_basedir, "emotions", "_defaults.json")
with open(file_path, 'r') as json_file:
emotions = json.load(json_file)
@@ -339,8 +350,8 @@ class MainFrame(wx.Frame):
targetpose = emotions.get(emotion, {})
targetpose_values = targetpose
#targetpose_values = list(targetpose.values())
#print("targetpose: ", targetpose, "for ", emotion)
# targetpose_values = list(targetpose.values())
# print("targetpose: ", targetpose, "for ", emotion)
return targetpose_values
def animateToEmotion(self, current_pose_list, target_pose_dict):
@@ -362,9 +373,9 @@ class MainFrame(wx.Frame):
return transitionPose
def animationMain(self):
self.ifacialmocap_pose = self.animationBlink()
self.ifacialmocap_pose = self.animationHeadMove()
self.ifacialmocap_pose = self.animationTalking()
self.ifacialmocap_pose = self.animationBlink()
self.ifacialmocap_pose = self.animationHeadMove()
self.ifacialmocap_pose = self.animationTalking()
return self.ifacialmocap_pose
def filter_by_index(self, current_pose_list, index):
@@ -414,7 +425,6 @@ class MainFrame(wx.Frame):
self.animation_left_panel_sizer.Fit(self.animation_left_panel)
# Right Column (Sliders)
self.animation_right_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER)
self.animation_right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
self.animation_right_panel.SetSizer(self.animation_right_panel_sizer)
@@ -442,15 +452,13 @@ class MainFrame(wx.Frame):
self.output_background_choice.SetSelection(0)
self.animation_right_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND)
# These are applied to `ifacialmocap_pose`, so we can only use names that are defined there (see `update_ifacialmocap_pose`).
blendshape_groups = {
'Eyes': ['eyeLookOutLeft', 'eyeLookOutRight', 'eyeLookDownLeft', 'eyeLookUpLeft', 'eyeWideLeft', 'eyeWideRight'],
'Mouth': ['mouthFrownLeft'],
'Mouth': ['mouthSmileLeft', 'mouthFrownLeft'],
'Cheek': ['cheekSquintLeft', 'cheekSquintRight', 'cheekPuff'],
'Brow': ['browDownLeft', 'browOuterUpLeft', 'browDownRight', 'browOuterUpRight', 'browInnerUp'],
'Eyelash': ['mouthSmileLeft'],
# 'Eyelash': [],
'Nose': ['noseSneerLeft', 'noseSneerRight'],
'Misc': ['tongueOut']
}
@@ -492,26 +500,26 @@ class MainFrame(wx.Frame):
def on_slider_change(self, event):
slider = event.GetEventObject()
value = slider.GetValue() / 100.0 # Divide by 100 to get the actual float value
#print(value)
# print(value)
slider_name = slider.GetName()
self.ifacialmocap_pose[slider_name] = value
def create_ui(self):
#MAke the UI Elements
# Make the UI Elements
self.main_sizer = wx.BoxSizer(wx.VERTICAL)
self.SetSizer(self.main_sizer)
self.SetAutoLayout(1)
self.capture_pose_lock = threading.Lock()
#Main panel with JPS
# Main panel with JPS
self.create_animation_panel(self)
self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5))
def update_capture_panel(self, event: wx.Event):
data = self.ifacialmocap_pose
for rotation_name in ROTATION_NAMES:
value = data[rotation_name]
for rotation_name in mocap_constants.ROTATION_NAMES:
value = data[rotation_name] # TODO/FIXME: updating unused variable; what was this supposed to do?
@staticmethod
def convert_to_100(x):
@@ -561,7 +569,7 @@ class MainFrame(wx.Frame):
pose_dict = dict(zip(pose_names, combine_pose))
return pose_dict
def determine_data_type(self, data):
def determine_data_type(self, data): # TODO: is this needed, nothing in the project seems to call it; and why not just use `isinstance` directly?
if isinstance(data, list):
print("It's a list.")
elif isinstance(data, dict):
@@ -600,8 +608,8 @@ class MainFrame(wx.Frame):
else:
raise ValueError("Unsupported data type passed to dict_to_tensor.")
def update_ifacualmocap_pose(self, ifacualmocap_pose, emotion_pose):
# Update Values - The following values are in emotion_pose but not defined in ifacualmocap_pose
def update_ifacialmocap_pose(self, ifacialmocap_pose, emotion_pose):
# Update Values - The following values are in emotion_pose but not defined in ifacialmocap_pose
# eye_happy_wink_left_index, eye_happy_wink_right_index
# eye_surprised_left_index, eye_surprised_right_index
# eye_relaxed_left_index, eye_relaxed_right_index
@@ -616,56 +624,55 @@ class MainFrame(wx.Frame):
# body_z_index
# breathing_index
ifacualmocap_pose['browDownLeft'] = emotion_pose['eyebrow_troubled_left_index']
ifacualmocap_pose['browDownRight'] = emotion_pose['eyebrow_troubled_right_index']
ifacualmocap_pose['browOuterUpLeft'] = emotion_pose['eyebrow_angry_left_index']
ifacualmocap_pose['browOuterUpRight'] = emotion_pose['eyebrow_angry_right_index']
ifacualmocap_pose['browInnerUp'] = emotion_pose['eyebrow_happy_left_index']
ifacualmocap_pose['browInnerUp'] += emotion_pose['eyebrow_happy_right_index']
ifacualmocap_pose['browDownLeft'] = emotion_pose['eyebrow_raised_left_index']
ifacualmocap_pose['browDownRight'] = emotion_pose['eyebrow_raised_right_index']
ifacualmocap_pose['browDownLeft'] += emotion_pose['eyebrow_lowered_left_index']
ifacualmocap_pose['browDownRight'] += emotion_pose['eyebrow_lowered_right_index']
ifacualmocap_pose['browDownLeft'] += emotion_pose['eyebrow_serious_left_index']
ifacualmocap_pose['browDownRight'] += emotion_pose['eyebrow_serious_right_index']
ifacialmocap_pose['browDownLeft'] = emotion_pose['eyebrow_troubled_left_index']
ifacialmocap_pose['browDownRight'] = emotion_pose['eyebrow_troubled_right_index']
ifacialmocap_pose['browOuterUpLeft'] = emotion_pose['eyebrow_angry_left_index']
ifacialmocap_pose['browOuterUpRight'] = emotion_pose['eyebrow_angry_right_index']
ifacialmocap_pose['browInnerUp'] = emotion_pose['eyebrow_happy_left_index']
ifacialmocap_pose['browInnerUp'] += emotion_pose['eyebrow_happy_right_index']
ifacialmocap_pose['browDownLeft'] = emotion_pose['eyebrow_raised_left_index']
ifacialmocap_pose['browDownRight'] = emotion_pose['eyebrow_raised_right_index']
ifacialmocap_pose['browDownLeft'] += emotion_pose['eyebrow_lowered_left_index']
ifacialmocap_pose['browDownRight'] += emotion_pose['eyebrow_lowered_right_index']
ifacialmocap_pose['browDownLeft'] += emotion_pose['eyebrow_serious_left_index']
ifacialmocap_pose['browDownRight'] += emotion_pose['eyebrow_serious_right_index']
# Update eye values
ifacualmocap_pose['eyeWideLeft'] = emotion_pose['eye_surprised_left_index']
ifacualmocap_pose['eyeWideRight'] = emotion_pose['eye_surprised_right_index']
ifacialmocap_pose['eyeWideLeft'] = emotion_pose['eye_surprised_left_index']
ifacialmocap_pose['eyeWideRight'] = emotion_pose['eye_surprised_right_index']
# Update eye blink (though we will overwrite it later)
ifacualmocap_pose['eyeBlinkLeft'] = emotion_pose['eye_wink_left_index']
ifacualmocap_pose['eyeBlinkRight'] = emotion_pose['eye_wink_right_index']
ifacialmocap_pose['eyeBlinkLeft'] = emotion_pose['eye_wink_left_index']
ifacialmocap_pose['eyeBlinkRight'] = emotion_pose['eye_wink_right_index']
# Update iris rotation values
ifacualmocap_pose['eyeLookInLeft'] = -emotion_pose['iris_rotation_y_index']
ifacualmocap_pose['eyeLookOutLeft'] = emotion_pose['iris_rotation_y_index']
ifacualmocap_pose['eyeLookInRight'] = emotion_pose['iris_rotation_y_index']
ifacualmocap_pose['eyeLookOutRight'] = -emotion_pose['iris_rotation_y_index']
ifacualmocap_pose['eyeLookUpLeft'] = emotion_pose['iris_rotation_x_index']
ifacualmocap_pose['eyeLookDownLeft'] = -emotion_pose['iris_rotation_x_index']
ifacualmocap_pose['eyeLookUpRight'] = emotion_pose['iris_rotation_x_index']
ifacualmocap_pose['eyeLookDownRight'] = -emotion_pose['iris_rotation_x_index']
ifacialmocap_pose['eyeLookInLeft'] = -emotion_pose['iris_rotation_y_index']
ifacialmocap_pose['eyeLookOutLeft'] = emotion_pose['iris_rotation_y_index']
ifacialmocap_pose['eyeLookInRight'] = emotion_pose['iris_rotation_y_index']
ifacialmocap_pose['eyeLookOutRight'] = -emotion_pose['iris_rotation_y_index']
ifacialmocap_pose['eyeLookUpLeft'] = emotion_pose['iris_rotation_x_index']
ifacialmocap_pose['eyeLookDownLeft'] = -emotion_pose['iris_rotation_x_index']
ifacialmocap_pose['eyeLookUpRight'] = emotion_pose['iris_rotation_x_index']
ifacialmocap_pose['eyeLookDownRight'] = -emotion_pose['iris_rotation_x_index']
# Update iris size values
ifacualmocap_pose['irisWideLeft'] = emotion_pose['iris_small_left_index']
ifacualmocap_pose['irisWideRight'] = emotion_pose['iris_small_right_index']
ifacialmocap_pose['irisWideLeft'] = emotion_pose['iris_small_left_index']
ifacialmocap_pose['irisWideRight'] = emotion_pose['iris_small_right_index']
# Update head rotation values
ifacualmocap_pose['headBoneX'] = -emotion_pose['head_x_index'] * 15.0
ifacualmocap_pose['headBoneY'] = -emotion_pose['head_y_index'] * 10.0
ifacualmocap_pose['headBoneZ'] = emotion_pose['neck_z_index'] * 15.0
ifacialmocap_pose['headBoneX'] = -emotion_pose['head_x_index'] * 15.0
ifacialmocap_pose['headBoneY'] = -emotion_pose['head_y_index'] * 10.0
ifacialmocap_pose['headBoneZ'] = emotion_pose['neck_z_index'] * 15.0
# Update mouth values
ifacualmocap_pose['mouthSmileLeft'] = emotion_pose['mouth_aaa_index']
ifacualmocap_pose['mouthSmileRight'] = emotion_pose['mouth_aaa_index']
ifacualmocap_pose['mouthFrownLeft'] = emotion_pose['mouth_lowered_corner_left_index']
ifacualmocap_pose['mouthFrownRight'] = emotion_pose['mouth_lowered_corner_right_index']
ifacualmocap_pose['mouthPressLeft'] = emotion_pose['mouth_raised_corner_left_index']
ifacualmocap_pose['mouthPressRight'] = emotion_pose['mouth_raised_corner_right_index']
ifacialmocap_pose['mouthSmileLeft'] = emotion_pose['mouth_aaa_index']
ifacialmocap_pose['mouthSmileRight'] = emotion_pose['mouth_aaa_index']
ifacialmocap_pose['mouthFrownLeft'] = emotion_pose['mouth_lowered_corner_left_index']
ifacialmocap_pose['mouthFrownRight'] = emotion_pose['mouth_lowered_corner_right_index']
ifacialmocap_pose['mouthPressLeft'] = emotion_pose['mouth_raised_corner_left_index']
ifacialmocap_pose['mouthPressRight'] = emotion_pose['mouth_raised_corner_right_index']
return ifacualmocap_pose
return ifacialmocap_pose
def update_blinking_pose(self, tranisitiondPose):
PARTS = ['wink_left_index', 'wink_right_index']
@@ -702,11 +709,11 @@ class MainFrame(wx.Frame):
return updated_list
def update_sway_pose_good(self, tranisitiondPose):
def update_sway_pose_good(self, tranisitiondPose): # TODO: good? why is there a bad one, too? keep only one!
MOVEPARTS = ['head_y_index']
updated_list = []
print( self.start_values, self.targets, self.progress, self.direction )
print(self.start_values, self.targets, self.progress, self.direction)
for item in tranisitiondPose:
key, value = item.split(': ')
@@ -725,11 +732,9 @@ class MainFrame(wx.Frame):
self.targets[key] = current_value + random.uniform(-1, 1)
self.progress[key] = 0 # Reset progress when setting a new target
# Use lerp to interpolate between start and target values
# Linearly interpolate between start and target values
new_value = self.start_values[key] + self.progress[key] * (self.targets[key] - self.start_values[key])
# Ensure the value remains within bounds (just in case)
new_value = min(max(new_value, -1), 1)
new_value = min(max(new_value, -1), 1) # clip to bounds (just in case)
# Update progress based on direction
self.progress[key] += 0.02 * self.direction[key]
@@ -744,7 +749,7 @@ class MainFrame(wx.Frame):
MOVEPARTS = ['head_y_index']
updated_list = []
#print( self.start_values, self.targets, self.progress, self.direction )
# print( self.start_values, self.targets, self.progress, self.direction )
for item in tranisitiondPose:
key, value = item.split(': ')
@@ -752,11 +757,9 @@ class MainFrame(wx.Frame):
if key in MOVEPARTS:
current_value = float(value)
# Use lerp to interpolate between start and target values
# Linearly interpolate between start and target values
new_value = self.start_values[key] + self.progress[key] * (self.targets[key] - self.start_values[key])
# Ensure the value remains within bounds (just in case)
new_value = min(max(new_value, -1), 1)
new_value = min(max(new_value, -1), 1) # clip to bounds (just in case)
# Check if we've reached the target or start value
is_close_to_target = abs(new_value - self.targets[key]) < 0.04
@@ -808,7 +811,7 @@ class MainFrame(wx.Frame):
if key in transition_dict:
# If the key is 'wink_left_index' or 'wink_right_index', set the value directly dont animate blinks
if key in ['wink_left_index', 'wink_right_index']: # BLINK FIX
if key in ['wink_left_index', 'wink_right_index']: # BLINK FIX
last_value = transition_dict[key]
# For all other keys, increment its value by 0.1 of the delta and clip it to the target
@@ -820,6 +823,7 @@ class MainFrame(wx.Frame):
updated_last_transition_pose.append(f"{key}: {last_value}")
# If any value is less than the target, set inMotion to True
# TODO/FIXME: inMotion is not actually used; what was this supposed to do?
if any(last_transition_dict[k] < transition_dict[k] for k in last_transition_dict if k in transition_dict):
inMotion = True
else:
@@ -847,63 +851,64 @@ class MainFrame(wx.Frame):
MainFrame.load_image(self, event=None, file_path=None) # call load_image function here
return
#OLD METHOD
#ifacialmocap_pose = self.animationMain() #GET ANIMATION CHANGES
#current_posesaved = self.pose_converter.convert(ifacialmocap_pose)
#combined_posesaved = current_posesaved
# # OLD METHOD
# ifacialmocap_pose = self.animationMain() # GET ANIMATION CHANGES
# current_posesaved = self.pose_converter.convert(ifacialmocap_pose)
# combined_posesaved = current_posesaved
#NEW METHOD
#CREATES THE DEFAULT POSE AND STORES OBJ IN STRING
#ifacialmocap_pose = self.animationMain() #DISABLE FOR TESTING!!!!!!!!!!!!!!!!!!!!!!!!
# NEW METHOD
# CREATES THE DEFAULT POSE AND STORES OBJ IN STRING
# ifacialmocap_pose = self.animationMain() # DISABLE FOR TESTING!!!!!!!!!!!!!!!!!!!!!!!!
ifacialmocap_pose = self.ifacialmocap_pose
#print("ifacialmocap_pose", ifacialmocap_pose)
# print("ifacialmocap_pose", ifacialmocap_pose)
#GET EMOTION SETTING
# GET EMOTION SETTING
emotion_pose = self.get_emotion_values(emotion)
#print("emotion_pose ", emotion_pose)
# print("emotion_pose ", emotion_pose)
#MERGE EMOTION SETTING WITH CURRENT OUTPUT
updated_pose = self.update_ifacualmocap_pose(ifacialmocap_pose, emotion_pose)
#print("updated_pose ", updated_pose)
# MERGE EMOTION SETTING WITH CURRENT OUTPUT
# NOTE: This is a mutating method that overwrites the original `ifacialmocap_pose`.
updated_pose = self.update_ifacialmocap_pose(ifacialmocap_pose, emotion_pose)
# print("updated_pose ", updated_pose)
#CONVERT RESULT TO FORMAT NN CAN USE
# CONVERT RESULT TO FORMAT NN CAN USE
current_pose = self.pose_converter.convert(updated_pose)
#print("current_pose ", current_pose)
# print("current_pose ", current_pose)
#SEND THROUGH CONVERT
# SEND THROUGH CONVERT
current_pose = self.pose_converter.convert(ifacialmocap_pose)
#print("current_pose2 ", current_pose)
# print("current_pose2 ", current_pose)
#ADD LABELS/NAMES TO THE POSE
# ADD LABELS/NAMES TO THE POSE
names_current_pose = MainFrame.addNamestoConvert(current_pose)
#print("current pose :", names_current_pose)
# print("current pose :", names_current_pose)
#GET THE EMOTION VALUES again for some reason
# GET THE EMOTION VALUES again for some reason
emotion_pose2 = self.get_emotion_values(emotion)
#print("target pose :", emotion_pose2)
# print("target pose :", emotion_pose2)
#APPLY VALUES TO THE POSE AGAIN?? This needs to overwrite the values
# APPLY VALUES TO THE POSE AGAIN?? This needs to overwrite the values
tranisitiondPose = self.animateToEmotion(names_current_pose, emotion_pose2)
#print("combine pose :", tranisitiondPose)
# print("combine pose :", tranisitiondPose)
#smooth animate
#print("LAST VALUES: ", lasttranisitiondPose)
#print("TARGER VALUES: ", tranisitiondPose)
# smooth animate
# print("LAST VALUES: ", lasttranisitiondPose)
    # print("TARGET VALUES: ", tranisitiondPose)
if lasttranisitiondPose != "NotInit":
tranisitiondPose = self.update_transition_pose(lasttranisitiondPose, tranisitiondPose)
#print("smoothed: ", tranisitiondPose)
# print("smoothed: ", tranisitiondPose)
#Animate blinking
# Animate blinking
tranisitiondPose = self.update_blinking_pose(tranisitiondPose)
#Animate Head Sway
# Animate Head Sway
tranisitiondPose = self.update_sway_pose(tranisitiondPose)
#Animate Talking
# Animate Talking
tranisitiondPose = self.update_talking_pose(tranisitiondPose)
#reformat the data correctly
# reformat the data correctly
parsed_data = []
for item in tranisitiondPose:
key, value_str = item.split(': ')
@@ -911,7 +916,7 @@ class MainFrame(wx.Frame):
parsed_data.append((key, value))
tranisitiondPosenew = [value for _, value in parsed_data]
#not sure what this is for TBH
# not sure what this is for TBH
ifacialmocap_pose = tranisitiondPosenew
if self.torch_source_image is None:
@@ -921,7 +926,7 @@ class MainFrame(wx.Frame):
del dc
return
#pose = torch.tensor(tranisitiondPosenew, device=self.device, dtype=self.poser.get_dtype())
# pose = torch.tensor(tranisitiondPosenew, device=self.device, dtype=self.poser.get_dtype())
pose = self.dict_to_tensor(tranisitiondPosenew).to(device=self.device, dtype=self.poser.get_dtype())
with torch.no_grad():
@@ -931,27 +936,25 @@ class MainFrame(wx.Frame):
c, h, w = output_image.shape
output_image = (255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1)).reshape(h, w, c).byte()
numpy_image = output_image.detach().cpu().numpy()
wx_image = wx.ImageFromBuffer(numpy_image.shape[0],
numpy_image.shape[1],
numpy_image[:, :, 0:3].tobytes(),
numpy_image[:, :, 3].tobytes())
numpy_image.shape[1],
numpy_image[:, :, 0:3].tobytes(),
numpy_image[:, :, 3].tobytes())
wx_bitmap = wx_image.ConvertToBitmap()
dc = wx.MemoryDC()
dc.SelectObject(self.result_image_bitmap)
dc.Clear()
dc.DrawBitmap(wx_bitmap,
(self.poser.get_image_size() - numpy_image.shape[0]) // 2,
(self.poser.get_image_size() - numpy_image.shape[1]) // 2, True)
(self.poser.get_image_size() - numpy_image.shape[0]) // 2,
(self.poser.get_image_size() - numpy_image.shape[1]) // 2, True)
numpy_image_bgra = numpy_image[:, :, [2, 1, 0, 3]] # Convert color channels from RGB to BGR and keep alpha channel
numpy_image_bgra = numpy_image[:, :, [2, 1, 0, 3]] # Convert color channels from RGB to BGR and keep alpha channel
global_result_image = numpy_image_bgra
del dc
time_now = time.time_ns()
if self.last_update_time is not None:
elapsed_time = time_now - self.last_update_time
@@ -963,16 +966,15 @@ class MainFrame(wx.Frame):
self.last_update_time = time_now
if(initAMI == True): #If the models are just now initalized stop animation to save
        if initAMI:  # If the models are just now initialized stop animation to save
global_timer_paused = True
initAMI = False
if random.random() <= 0.01:
trimmed_fps = round(fps, 1)
#print("talkinghead FPS: {:.1f}".format(trimmed_fps))
print("talkinghead FPS: {:.1f}".format(trimmed_fps))
#Store current pose to use as last pose on next loop
# Store current pose to use as last pose on next loop
lasttranisitiondPose = tranisitiondPose
self.Refresh()
@@ -1028,7 +1030,7 @@ class MainFrame(wx.Frame):
except Exception as error:
print("Error: ", error)
global_reload = None #reset the globe load
global_reload = None # Reset the globe load
self.Refresh()
if __name__ == "__main__":
@@ -1041,7 +1043,10 @@ if __name__ == "__main__":
choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
help='The model to use.'
)
parser.add_argument('--char', type=str, required=False, help='The path to the character image.')
parser.add_argument('--char',
type=str,
required=False,
help='The filename of the character image under "tha3/images/".')
parser.add_argument(
'--device',
type=str,
@@ -1052,4 +1057,5 @@ if __name__ == "__main__":
)
args = parser.parse_args()
launch_gui(device=args.device, model=args.model)
global_basedir = "" # in standalone mode, cwd is the "talkinghead" directory
launch_gui(device=args.device, model=args.model, char=args.char, standalone=True)

File diff suppressed because it is too large Load Diff

View File

@@ -1,19 +1,19 @@
import torch
def load_poser(model: str, device: torch.device):
def load_poser(model: str, device: torch.device, modelsdir="talkinghead/tha3/models"):
print("Using the %s model." % model)
if model == "standard_float":
from tha3.poser.modes.standard_float import create_poser
return create_poser(device)
return create_poser(device, modelsdir=modelsdir)
elif model == "standard_half":
from tha3.poser.modes.standard_half import create_poser
return create_poser(device)
return create_poser(device, modelsdir=modelsdir)
elif model == "separable_float":
from tha3.poser.modes.separable_float import create_poser
return create_poser(device)
return create_poser(device, modelsdir=modelsdir)
elif model == "separable_half":
from tha3.poser.modes.separable_half import create_poser
return create_poser(device)
return create_poser(device, modelsdir=modelsdir)
else:
raise RuntimeError("Invalid model: '%s'" % model)
raise RuntimeError("Invalid model: '%s'" % model)

View File

@@ -33,4 +33,4 @@ def get_pose_parameters():
.add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \
.build()
.build()

View File

@@ -1,5 +1,6 @@
from enum import Enum
from typing import Dict, Optional, List
import os
from typing import Dict, List, Optional
import torch
from torch import Tensor
@@ -259,30 +260,31 @@ def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,
default_output_index: int = 0) -> GeneralPoser02:
default_output_index: int = 0,
modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if Network.eyebrow_decomposer.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_float"
file_name = dir + "/eyebrow_decomposer.pt"
file_name = os.path.join(modelsdir, "separable_float", "eyebrow_decomposer.pt")
module_file_names[Network.eyebrow_decomposer.name] = file_name
if Network.eyebrow_morphing_combiner.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_float"
file_name = dir + "/eyebrow_morphing_combiner.pt"
file_name = os.path.join(modelsdir, "separable_float", "eyebrow_morphing_combiner.pt")
module_file_names[Network.eyebrow_morphing_combiner.name] = file_name
if Network.face_morpher.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_float"
file_name = dir + "/face_morpher.pt"
file_name = os.path.join(modelsdir, "separable_float", "face_morpher.pt")
module_file_names[Network.face_morpher.name] = file_name
if Network.two_algo_face_body_rotator.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_float"
file_name = dir + "/two_algo_face_body_rotator.pt"
file_name = os.path.join(modelsdir, "separable_float", "two_algo_face_body_rotator.pt")
module_file_names[Network.two_algo_face_body_rotator.name] = file_name
if Network.editor.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_float"
file_name = dir + "/editor.pt"
file_name = os.path.join(modelsdir, "separable_float", "editor.pt")
module_file_names[Network.editor.name] = file_name
# fail-fast
for file_name in module_file_names.values():
if not os.path.exists(file_name):
raise FileNotFoundError(f"Model file {file_name} not found, please check the path.")
loaders = {
Network.eyebrow_decomposer.name:
lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),
@@ -328,4 +330,4 @@ if __name__ == "__main__":
print("%d:" % i, elapsed_time)
acc = acc + elapsed_time
print("average:", acc / repeat)
print("average:", acc / repeat)

View File

@@ -1,5 +1,6 @@
from enum import Enum
from typing import List, Dict, Optional
import os
from typing import Dict, List, Optional
import torch
from torch import Tensor
@@ -12,10 +13,10 @@ from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_03 import \
EyebrowMorphingCombiner03Factory, EyebrowMorphingCombiner03Args, EyebrowMorphingCombiner03
from tha3.nn.face_morpher.face_morpher_09 import FaceMorpher09Factory, FaceMorpher09Args
from tha3.poser.general_poser_02 import GeneralPoser02
from tha3.poser.poser import PoseParameterCategory, PoseParameters
from tha3.nn.editor.editor_07 import Editor07, Editor07Args
from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \
TwoAlgoFaceBodyRotator05Args
from tha3.poser.modes.pose_parameters import get_pose_parameters
from tha3.util import torch_load
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc
from tha3.compute.cached_computation_protocol import CachedComputationProtocol
@@ -253,69 +254,35 @@ def load_editor(file_name) -> Module:
return module
def get_pose_parameters():
return PoseParameters.Builder() \
.add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \
.add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \
.add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \
.build()
def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
eyebrow_morphed_image_index: int = EyebrowMorphingCombiner03.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,
default_output_index: int = 0) -> GeneralPoser02:
default_output_index: int = 0,
modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if Network.eyebrow_decomposer.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_half"
file_name = dir + "/eyebrow_decomposer.pt"
file_name = os.path.join(modelsdir, "separable_half", "eyebrow_decomposer.pt")
module_file_names[Network.eyebrow_decomposer.name] = file_name
if Network.eyebrow_morphing_combiner.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_half"
file_name = dir + "/eyebrow_morphing_combiner.pt"
file_name = os.path.join(modelsdir, "separable_half", "eyebrow_morphing_combiner.pt")
module_file_names[Network.eyebrow_morphing_combiner.name] = file_name
if Network.face_morpher.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_half"
file_name = dir + "/face_morpher.pt"
file_name = os.path.join(modelsdir, "separable_half", "face_morpher.pt")
module_file_names[Network.face_morpher.name] = file_name
if Network.two_algo_face_body_rotator.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_half"
file_name = dir + "/two_algo_face_body_rotator.pt"
file_name = os.path.join(modelsdir, "separable_half", "two_algo_face_body_rotator.pt")
module_file_names[Network.two_algo_face_body_rotator.name] = file_name
if Network.editor.name not in module_file_names:
dir = "talkinghead/tha3/models/separable_half"
file_name = dir + "/editor.pt"
file_name = os.path.join(modelsdir, "separable_half", "editor.pt")
module_file_names[Network.editor.name] = file_name
# fail-fast
for file_name in module_file_names.values():
if not os.path.exists(file_name):
raise FileNotFoundError(f"Model file {file_name} not found, please check the path.")
loaders = {
Network.eyebrow_decomposer.name:
lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),

View File

@@ -1,5 +1,6 @@
from enum import Enum
from typing import List, Dict, Optional
import os
from typing import Dict, List, Optional
import torch
from torch import Tensor
@@ -12,10 +13,10 @@ from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \
EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00
from tha3.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory
from tha3.poser.general_poser_02 import GeneralPoser02
from tha3.poser.poser import PoseParameterCategory, PoseParameters
from tha3.nn.editor.editor_07 import Editor07, Editor07Args
from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \
TwoAlgoFaceBodyRotator05Args
from tha3.poser.modes.pose_parameters import get_pose_parameters
from tha3.util import torch_load
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc
from tha3.compute.cached_computation_protocol import CachedComputationProtocol
@@ -251,69 +252,35 @@ def load_editor(file_name) -> Module:
return module
def get_pose_parameters():
return PoseParameters.Builder() \
.add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \
.add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \
.add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \
.build()
def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,
default_output_index: int = 0) -> GeneralPoser02:
default_output_index: int = 0,
modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if Network.eyebrow_decomposer.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_float"
file_name = dir + "/eyebrow_decomposer.pt"
file_name = os.path.join(modelsdir, "standard_float", "eyebrow_decomposer.pt")
module_file_names[Network.eyebrow_decomposer.name] = file_name
if Network.eyebrow_morphing_combiner.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_float"
file_name = dir + "/eyebrow_morphing_combiner.pt"
file_name = os.path.join(modelsdir, "standard_float", "eyebrow_morphing_combiner.pt")
module_file_names[Network.eyebrow_morphing_combiner.name] = file_name
if Network.face_morpher.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_float"
file_name = dir + "/face_morpher.pt"
file_name = os.path.join(modelsdir, "standard_float", "face_morpher.pt")
module_file_names[Network.face_morpher.name] = file_name
if Network.two_algo_face_body_rotator.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_float"
file_name = dir + "/two_algo_face_body_rotator.pt"
file_name = os.path.join(modelsdir, "standard_float", "two_algo_face_body_rotator.pt")
module_file_names[Network.two_algo_face_body_rotator.name] = file_name
if Network.editor.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_float"
file_name = dir + "/editor.pt"
file_name = os.path.join(modelsdir, "standard_float", "editor.pt")
module_file_names[Network.editor.name] = file_name
# fail-fast
for file_name in module_file_names.values():
if not os.path.exists(file_name):
raise FileNotFoundError(f"Model file {file_name} not found, please check the path.")
loaders = {
Network.eyebrow_decomposer.name:
lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),

View File

@@ -1,5 +1,6 @@
from enum import Enum
from typing import List, Dict, Optional
import os
from typing import Dict, List, Optional
import torch
from torch import Tensor
@@ -12,10 +13,10 @@ from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \
EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00
from tha3.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory
from tha3.poser.general_poser_02 import GeneralPoser02
from tha3.poser.poser import PoseParameterCategory, PoseParameters
from tha3.nn.editor.editor_07 import Editor07, Editor07Args
from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \
TwoAlgoFaceBodyRotator05Args
from tha3.poser.modes.pose_parameters import get_pose_parameters
from tha3.util import torch_load
from tha3.compute.cached_computation_func import TensorListCachedComputationFunc
from tha3.compute.cached_computation_protocol import CachedComputationProtocol
@@ -251,69 +252,35 @@ def load_editor(file_name) -> Module:
return module
def get_pose_parameters():
return PoseParameters.Builder() \
.add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \
.add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \
.add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \
.add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \
.add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \
.add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \
.add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \
.add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \
.build()
def create_poser(
device: torch.device,
module_file_names: Optional[Dict[str, str]] = None,
eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX,
default_output_index: int = 0) -> GeneralPoser02:
default_output_index: int = 0,
modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02:
if module_file_names is None:
module_file_names = {}
if Network.eyebrow_decomposer.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_half"
file_name = dir + "/eyebrow_decomposer.pt"
file_name = os.path.join(modelsdir, "standard_half", "eyebrow_decomposer.pt")
module_file_names[Network.eyebrow_decomposer.name] = file_name
if Network.eyebrow_morphing_combiner.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_half"
file_name = dir + "/eyebrow_morphing_combiner.pt"
file_name = os.path.join(modelsdir, "standard_half", "eyebrow_morphing_combiner.pt")
module_file_names[Network.eyebrow_morphing_combiner.name] = file_name
if Network.face_morpher.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_half"
file_name = dir + "/face_morpher.pt"
file_name = os.path.join(modelsdir, "standard_half", "face_morpher.pt")
module_file_names[Network.face_morpher.name] = file_name
if Network.two_algo_face_body_rotator.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_half"
file_name = dir + "/two_algo_face_body_rotator.pt"
file_name = os.path.join(modelsdir, "standard_half", "two_algo_face_body_rotator.pt")
module_file_names[Network.two_algo_face_body_rotator.name] = file_name
if Network.editor.name not in module_file_names:
dir = "talkinghead/tha3/models/standard_half"
file_name = dir + "/editor.pt"
file_name = os.path.join(modelsdir, "standard_half", "editor.pt")
module_file_names[Network.editor.name] = file_name
# fail-fast
for file_name in module_file_names.values():
if not os.path.exists(file_name):
raise FileNotFoundError(f"Model file {file_name} not found, please check the path.")
loaders = {
Network.eyebrow_decomposer.name:
lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),