diff --git a/talkinghead/start_manual_poser.sh b/talkinghead/start_manual_poser.sh new file mode 100755 index 0000000..bfb6b51 --- /dev/null +++ b/talkinghead/start_manual_poser.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# +# Launch the THA3 manual poser app. +# +# This app can be used to generate static expression images, given just +# one static input image in the appropriate format. +# +# This app is standalone, and does not interact with SillyTavern. +# +# This must run in the "extras" conda venv! +# Do this first: +# conda activate extras +# +python -m tha3.app.manual_poser $@ diff --git a/talkinghead/start_standalone_app.sh b/talkinghead/start_standalone_app.sh new file mode 100755 index 0000000..fbaeb48 --- /dev/null +++ b/talkinghead/start_standalone_app.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# +# Launch THA3 in standalone app mode. +# +# This standalone app mode does not interact with SillyTavern. +# +# The usual way to run this fork of THA3 is as a SillyTavern-extras plugin. +# The standalone app mode comes from the original THA3 code, and is included +# for testing and debugging. +# +# If you want to manually pose a character (to generate static expression images), +# use `start_manual_poser.sh` instead. +# +# This must run in the "extras" conda venv! +# Do this first: +# conda activate extras +# +# The `--char=...` flag can be used to specify which image to load under "tha3/images". +# +python -m tha3.app.app --char=example.png $@ diff --git a/talkinghead/tha3/app/app.py b/talkinghead/tha3/app/app.py index 50a038a..0d81e65 100644 --- a/talkinghead/tha3/app/app.py +++ b/talkinghead/tha3/app/app.py @@ -1,3 +1,5 @@ +# TODO: Standalone app mode does not work yet. The SillyTavern-extras plugin mode works. + import argparse import ast import os @@ -7,19 +9,18 @@ import threading import time import torch import io -import torch.nn.functional as F import wx import numpy as np import json +import typing from PIL import Image -from torchvision import transforms from flask import Flask, Response from flask_cors import CORS from io import BytesIO sys.path.append(os.getcwd()) -from tha3.mocap.ifacialmocap_constants import * +from tha3.mocap import ifacialmocap_constants as mocap_constants from tha3.mocap.ifacialmocap_pose import create_default_ifacialmocap_pose from tha3.mocap.ifacialmocap_pose_converter import IFacialMocapPoseConverter from tha3.mocap.ifacialmocap_poser_converter_25 import create_ifacialmocap_pose_converter @@ -31,7 +32,7 @@ from tha3.util import ( ) from typing import Optional -# Global Variables +# Global variables global_source_image = None global_result_image = None global_reload = None @@ -43,7 +44,7 @@ lasttranisitiondPose = "NotInit" inMotion = False fps = 0 current_pose = None -storepath = os.path.join(os.getcwd(), "talkinghead", "emotions") +global_basedir = "talkinghead" # for SillyTavern-extras live mode; if running standalone, we override this later # Flask setup app = Flask(__name__) @@ -60,7 +61,7 @@ def setEmotion(_emotion): highest_score = item['score'] highest_label = item['label'] - #print("Applying ", emotion) + # print("Applying ", emotion) emotion = highest_label def unload(): @@ -85,10 +86,10 @@ def result_feed(): try: rgb_image = global_result_image[:, :, [2, 1, 0]] # Swap B and R channels pil_image = Image.fromarray(np.uint8(rgb_image)) # Convert to PIL Image - if global_result_image.shape[2] == 4: # Check if there is an alpha channel present - alpha_channel = global_result_image[:, :, 3] # Extract alpha channel - pil_image.putalpha(Image.fromarray(np.uint8(alpha_channel))) # Set alpha channel in the PIL Image - buffer = io.BytesIO() # Save as PNG with RGBA mode + if global_result_image.shape[2] == 4: # Check if there is an alpha channel present + alpha_channel = global_result_image[:, :, 3] # Extract alpha channel + pil_image.putalpha(Image.fromarray(np.uint8(alpha_channel))) # Set alpha channel in the PIL Image + buffer = io.BytesIO() # Save as PNG with RGBA mode pil_image.save(buffer, format='PNG') image_bytes = buffer.getvalue() except Exception as e: @@ -100,19 +101,20 @@ def result_feed(): return Response(generate(), mimetype='multipart/x-mixed-replace; boundary=frame') def talkinghead_load_file(stream): + global global_basedir global global_source_image global global_reload global global_timer_paused global_timer_paused = False try: - pil_image = Image.open(stream) # Load the image using PIL.Image.open - img_data = BytesIO() # Create a copy of the image data in memory using BytesIO + pil_image = Image.open(stream) # Load the image using PIL.Image.open + img_data = BytesIO() # Create a copy of the image data in memory using BytesIO pil_image.save(img_data, format='PNG') - global_reload = Image.open(BytesIO(img_data.getvalue())) # Set the global_reload to the copy of the image data + global_reload = Image.open(BytesIO(img_data.getvalue())) # Set the global_reload to the copy of the image data except Image.UnidentifiedImageError: - print(f"Could not load image from file, loading blank") - full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.sep.join(["talkinghead", "tha3", "images", "inital.png"]))) + print("Could not load image from file, loading blank") + full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.join(global_basedir, "tha3", "images", "inital.png"))) MainFrame.load_image(None, full_path) global_timer_paused = True return 'OK' @@ -121,32 +123,41 @@ def convert_linear_to_srgb(image: torch.Tensor) -> torch.Tensor: rgb_image = torch_linear_to_srgb(image[0:3, :, :]) return torch.cat([rgb_image, image[3:4, :, :]], dim=0) -def launch_gui(device, model): +def launch_gui(device: str, model: str, char: typing.Optional[str] = None, standalone: bool = False): + """ + device: "cpu" or "cuda" + model: one of the folder names inside "talkinghead/tha3/models/" + char: name of png file inside "talkinghead/tha3/images/"; if not given, defaults to "inital.png". + """ + global global_basedir global initAMI initAMI = True + # TODO: We could use this to parse the arguments that were provided to `server.py`, but we don't currently use the parser output. parser = argparse.ArgumentParser(description='uWu Waifu') - # Add other parser arguments here - args, unknown = parser.parse_known_args() - try: - poser = load_poser(model, device) - pose_converter = create_ifacialmocap_pose_converter() #creates a list of 45 + if char is None: + char = "inital.png" - app = wx.App(redirect=False) + try: + poser = load_poser(model, device, modelsdir=os.path.join(global_basedir, "tha3", "models")) + pose_converter = create_ifacialmocap_pose_converter() # creates a list of 45 + + app = wx.App() main_frame = MainFrame(poser, pose_converter, device) main_frame.SetSize((750, 600)) - #Lload default image (you can pass args.char if required) - full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.sep.join(["talkinghead", "tha3", "images", "inital.png"]))) + # Load character image + full_path = os.path.join(os.getcwd(), os.path.normpath(os.path.join(global_basedir, "tha3", "images", char))) main_frame.load_image(None, full_path) - #main_frame.Show(True) + if standalone: + main_frame.Show(True) main_frame.capture_timer.Start(100) main_frame.animation_timer.Start(100) - wx.DisableAsserts() #prevent popup about debug alert closed from other threads + wx.DisableAsserts() # prevent popup about debug alert closed from other threads app.MainLoop() except RuntimeError as e: @@ -235,7 +246,7 @@ class MainFrame(wx.Frame): current_pose = self.ifacialmocap_pose # NOTE: randomize mouth - for blendshape_name in BLENDSHAPE_NAMES: + for blendshape_name in mocap_constants.BLENDSHAPE_NAMES: if "jawOpen" in blendshape_name: if is_talking or is_talking_override: current_pose[blendshape_name] = self.random_generate_value(-5000, 5000, abs(1 - current_pose[blendshape_name])) @@ -247,7 +258,7 @@ class MainFrame(wx.Frame): def animationHeadMove(self): current_pose = self.ifacialmocap_pose - for key in [HEAD_BONE_Y]: #can add more to this list if needed + for key in [mocap_constants.HEAD_BONE_Y]: # can add more to this list if needed current_pose[key] = self.random_generate_value(-20, 20, current_pose[key]) return current_pose @@ -265,29 +276,30 @@ class MainFrame(wx.Frame): return current_pose def addNamestoConvert(pose): + # TODO: What are the unknown keys? index_to_name = { - 0: 'eyebrow_troubled_left_index', #COMBACK TO UNK - 1: 'eyebrow_troubled_right_index',#COMBACK TO UNK + 0: 'eyebrow_troubled_left_index', + 1: 'eyebrow_troubled_right_index', 2: 'eyebrow_angry_left_index', 3: 'eyebrow_angry_right_index', - 4: 'unknown1', #COMBACK TO UNK - 5: 'unknown2', #COMBACK TO UNK + 4: 'unknown1', # COMBACK TO UNK + 5: 'unknown2', # COMBACK TO UNK 6: 'eyebrow_raised_left_index', 7: 'eyebrow_raised_right_index', 8: 'eyebrow_happy_left_index', 9: 'eyebrow_happy_right_index', - 10: 'unknown3', #COMBACK TO UNK - 11: 'unknown4', #COMBACK TO UNK + 10: 'unknown3', # COMBACK TO UNK + 11: 'unknown4', # COMBACK TO UNK 12: 'wink_left_index', 13: 'wink_right_index', 14: 'eye_happy_wink_left_index', 15: 'eye_happy_wink_right_index', 16: 'eye_surprised_left_index', 17: 'eye_surprised_right_index', - 18: 'unknown5', #COMBACK TO UNK - 19: 'unknown6', #COMBACK TO UNK - 20: 'unknown7', #COMBACK TO UNK - 21: 'unknown8', #COMBACK TO UNK + 18: 'unknown5', # COMBACK TO UNK + 19: 'unknown6', # COMBACK TO UNK + 20: 'unknown7', # COMBACK TO UNK + 21: 'unknown8', # COMBACK TO UNK 22: 'eye_raised_lower_eyelid_left_index', 23: 'eye_raised_lower_eyelid_right_index', 24: 'iris_small_left_index', @@ -295,14 +307,14 @@ class MainFrame(wx.Frame): 26: 'mouth_aaa_index', 27: 'mouth_iii_index', 28: 'mouth_ooo_index', - 29: 'unknown9a', #COMBACK TO UNK + 29: 'unknown9a', # COMBACK TO UNK 30: 'mouth_ooo_index2', - 31: 'unknown9', #COMBACK TO UNK - 32: 'unknown10', #COMBACK TO UNK - 33: 'unknown11', #COMBACK TO UNK + 31: 'unknown9', # COMBACK TO UNK + 32: 'unknown10', # COMBACK TO UNK + 33: 'unknown11', # COMBACK TO UNK 34: 'mouth_raised_corner_left_index', 35: 'mouth_raised_corner_right_index', - 36: 'unknown12', + 36: 'unknown12', # COMBACK TO UNK 37: 'iris_rotation_x_index', 38: 'iris_rotation_y_index', 39: 'head_x_index', @@ -321,17 +333,16 @@ class MainFrame(wx.Frame): return output - def get_emotion_values(self, emotion): # Place to define emotion presets - global storepath + def get_emotion_values(self, emotion): # Place to define emotion presets + global global_basedir - #print(emotion) - file_path = os.path.join(storepath, emotion + ".json") - #print("trying: ", file_path) + # print(emotion) + file_path = os.path.join(global_basedir, "emotions", emotion + ".json") + # print("trying: ", file_path) if not os.path.exists(file_path): - print("using backup for: ", file_path) - file_path = os.path.join(storepath, "_defaults.json") - + print("using backup for: ", file_path) + file_path = os.path.join(global_basedir, "emotions", "_defaults.json") with open(file_path, 'r') as json_file: emotions = json.load(json_file) @@ -339,8 +350,8 @@ class MainFrame(wx.Frame): targetpose = emotions.get(emotion, {}) targetpose_values = targetpose - #targetpose_values = list(targetpose.values()) - #print("targetpose: ", targetpose, "for ", emotion) + # targetpose_values = list(targetpose.values()) + # print("targetpose: ", targetpose, "for ", emotion) return targetpose_values def animateToEmotion(self, current_pose_list, target_pose_dict): @@ -362,9 +373,9 @@ class MainFrame(wx.Frame): return transitionPose def animationMain(self): - self.ifacialmocap_pose = self.animationBlink() - self.ifacialmocap_pose = self.animationHeadMove() - self.ifacialmocap_pose = self.animationTalking() + self.ifacialmocap_pose = self.animationBlink() + self.ifacialmocap_pose = self.animationHeadMove() + self.ifacialmocap_pose = self.animationTalking() return self.ifacialmocap_pose def filter_by_index(self, current_pose_list, index): @@ -414,7 +425,6 @@ class MainFrame(wx.Frame): self.animation_left_panel_sizer.Fit(self.animation_left_panel) # Right Column (Sliders) - self.animation_right_panel = wx.Panel(self.animation_panel, style=wx.SIMPLE_BORDER) self.animation_right_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.animation_right_panel.SetSizer(self.animation_right_panel_sizer) @@ -442,15 +452,13 @@ class MainFrame(wx.Frame): self.output_background_choice.SetSelection(0) self.animation_right_panel_sizer.Add(self.output_background_choice, 0, wx.EXPAND) - - - + # These are applied to `ifacialmocap_pose`, so we can only use names that are defined there (see `update_ifacialmocap_pose`). blendshape_groups = { 'Eyes': ['eyeLookOutLeft', 'eyeLookOutRight', 'eyeLookDownLeft', 'eyeLookUpLeft', 'eyeWideLeft', 'eyeWideRight'], - 'Mouth': ['mouthFrownLeft'], + 'Mouth': ['mouthSmileLeft', 'mouthFrownLeft'], 'Cheek': ['cheekSquintLeft', 'cheekSquintRight', 'cheekPuff'], 'Brow': ['browDownLeft', 'browOuterUpLeft', 'browDownRight', 'browOuterUpRight', 'browInnerUp'], - 'Eyelash': ['mouthSmileLeft'], + # 'Eyelash': [], 'Nose': ['noseSneerLeft', 'noseSneerRight'], 'Misc': ['tongueOut'] } @@ -492,26 +500,26 @@ class MainFrame(wx.Frame): def on_slider_change(self, event): slider = event.GetEventObject() value = slider.GetValue() / 100.0 # Divide by 100 to get the actual float value - #print(value) + # print(value) slider_name = slider.GetName() self.ifacialmocap_pose[slider_name] = value def create_ui(self): - #MAke the UI Elements + # Make the UI Elements self.main_sizer = wx.BoxSizer(wx.VERTICAL) self.SetSizer(self.main_sizer) self.SetAutoLayout(1) self.capture_pose_lock = threading.Lock() - #Main panel with JPS + # Main panel with JPS self.create_animation_panel(self) self.main_sizer.Add(self.animation_panel, wx.SizerFlags(0).Expand().Border(wx.ALL, 5)) def update_capture_panel(self, event: wx.Event): data = self.ifacialmocap_pose - for rotation_name in ROTATION_NAMES: - value = data[rotation_name] + for rotation_name in mocap_constants.ROTATION_NAMES: + value = data[rotation_name] # TODO/FIXME: updating unused variable; what was this supposed to do? @staticmethod def convert_to_100(x): @@ -561,7 +569,7 @@ class MainFrame(wx.Frame): pose_dict = dict(zip(pose_names, combine_pose)) return pose_dict - def determine_data_type(self, data): + def determine_data_type(self, data): # TODO: is this needed, nothing in the project seems to call it; and why not just use `isinstance` directly? if isinstance(data, list): print("It's a list.") elif isinstance(data, dict): @@ -600,8 +608,8 @@ class MainFrame(wx.Frame): else: raise ValueError("Unsupported data type passed to dict_to_tensor.") - def update_ifacualmocap_pose(self, ifacualmocap_pose, emotion_pose): - # Update Values - The following values are in emotion_pose but not defined in ifacualmocap_pose + def update_ifacialmocap_pose(self, ifacialmocap_pose, emotion_pose): + # Update Values - The following values are in emotion_pose but not defined in ifacialmocap_pose # eye_happy_wink_left_index, eye_happy_wink_right_index # eye_surprised_left_index, eye_surprised_right_index # eye_relaxed_left_index, eye_relaxed_right_index @@ -616,56 +624,55 @@ class MainFrame(wx.Frame): # body_z_index # breathing_index - - ifacualmocap_pose['browDownLeft'] = emotion_pose['eyebrow_troubled_left_index'] - ifacualmocap_pose['browDownRight'] = emotion_pose['eyebrow_troubled_right_index'] - ifacualmocap_pose['browOuterUpLeft'] = emotion_pose['eyebrow_angry_left_index'] - ifacualmocap_pose['browOuterUpRight'] = emotion_pose['eyebrow_angry_right_index'] - ifacualmocap_pose['browInnerUp'] = emotion_pose['eyebrow_happy_left_index'] - ifacualmocap_pose['browInnerUp'] += emotion_pose['eyebrow_happy_right_index'] - ifacualmocap_pose['browDownLeft'] = emotion_pose['eyebrow_raised_left_index'] - ifacualmocap_pose['browDownRight'] = emotion_pose['eyebrow_raised_right_index'] - ifacualmocap_pose['browDownLeft'] += emotion_pose['eyebrow_lowered_left_index'] - ifacualmocap_pose['browDownRight'] += emotion_pose['eyebrow_lowered_right_index'] - ifacualmocap_pose['browDownLeft'] += emotion_pose['eyebrow_serious_left_index'] - ifacualmocap_pose['browDownRight'] += emotion_pose['eyebrow_serious_right_index'] + ifacialmocap_pose['browDownLeft'] = emotion_pose['eyebrow_troubled_left_index'] + ifacialmocap_pose['browDownRight'] = emotion_pose['eyebrow_troubled_right_index'] + ifacialmocap_pose['browOuterUpLeft'] = emotion_pose['eyebrow_angry_left_index'] + ifacialmocap_pose['browOuterUpRight'] = emotion_pose['eyebrow_angry_right_index'] + ifacialmocap_pose['browInnerUp'] = emotion_pose['eyebrow_happy_left_index'] + ifacialmocap_pose['browInnerUp'] += emotion_pose['eyebrow_happy_right_index'] + ifacialmocap_pose['browDownLeft'] = emotion_pose['eyebrow_raised_left_index'] + ifacialmocap_pose['browDownRight'] = emotion_pose['eyebrow_raised_right_index'] + ifacialmocap_pose['browDownLeft'] += emotion_pose['eyebrow_lowered_left_index'] + ifacialmocap_pose['browDownRight'] += emotion_pose['eyebrow_lowered_right_index'] + ifacialmocap_pose['browDownLeft'] += emotion_pose['eyebrow_serious_left_index'] + ifacialmocap_pose['browDownRight'] += emotion_pose['eyebrow_serious_right_index'] # Update eye values - ifacualmocap_pose['eyeWideLeft'] = emotion_pose['eye_surprised_left_index'] - ifacualmocap_pose['eyeWideRight'] = emotion_pose['eye_surprised_right_index'] + ifacialmocap_pose['eyeWideLeft'] = emotion_pose['eye_surprised_left_index'] + ifacialmocap_pose['eyeWideRight'] = emotion_pose['eye_surprised_right_index'] # Update eye blink (though we will overwrite it later) - ifacualmocap_pose['eyeBlinkLeft'] = emotion_pose['eye_wink_left_index'] - ifacualmocap_pose['eyeBlinkRight'] = emotion_pose['eye_wink_right_index'] + ifacialmocap_pose['eyeBlinkLeft'] = emotion_pose['eye_wink_left_index'] + ifacialmocap_pose['eyeBlinkRight'] = emotion_pose['eye_wink_right_index'] # Update iris rotation values - ifacualmocap_pose['eyeLookInLeft'] = -emotion_pose['iris_rotation_y_index'] - ifacualmocap_pose['eyeLookOutLeft'] = emotion_pose['iris_rotation_y_index'] - ifacualmocap_pose['eyeLookInRight'] = emotion_pose['iris_rotation_y_index'] - ifacualmocap_pose['eyeLookOutRight'] = -emotion_pose['iris_rotation_y_index'] - ifacualmocap_pose['eyeLookUpLeft'] = emotion_pose['iris_rotation_x_index'] - ifacualmocap_pose['eyeLookDownLeft'] = -emotion_pose['iris_rotation_x_index'] - ifacualmocap_pose['eyeLookUpRight'] = emotion_pose['iris_rotation_x_index'] - ifacualmocap_pose['eyeLookDownRight'] = -emotion_pose['iris_rotation_x_index'] + ifacialmocap_pose['eyeLookInLeft'] = -emotion_pose['iris_rotation_y_index'] + ifacialmocap_pose['eyeLookOutLeft'] = emotion_pose['iris_rotation_y_index'] + ifacialmocap_pose['eyeLookInRight'] = emotion_pose['iris_rotation_y_index'] + ifacialmocap_pose['eyeLookOutRight'] = -emotion_pose['iris_rotation_y_index'] + ifacialmocap_pose['eyeLookUpLeft'] = emotion_pose['iris_rotation_x_index'] + ifacialmocap_pose['eyeLookDownLeft'] = -emotion_pose['iris_rotation_x_index'] + ifacialmocap_pose['eyeLookUpRight'] = emotion_pose['iris_rotation_x_index'] + ifacialmocap_pose['eyeLookDownRight'] = -emotion_pose['iris_rotation_x_index'] # Update iris size values - ifacualmocap_pose['irisWideLeft'] = emotion_pose['iris_small_left_index'] - ifacualmocap_pose['irisWideRight'] = emotion_pose['iris_small_right_index'] + ifacialmocap_pose['irisWideLeft'] = emotion_pose['iris_small_left_index'] + ifacialmocap_pose['irisWideRight'] = emotion_pose['iris_small_right_index'] # Update head rotation values - ifacualmocap_pose['headBoneX'] = -emotion_pose['head_x_index'] * 15.0 - ifacualmocap_pose['headBoneY'] = -emotion_pose['head_y_index'] * 10.0 - ifacualmocap_pose['headBoneZ'] = emotion_pose['neck_z_index'] * 15.0 + ifacialmocap_pose['headBoneX'] = -emotion_pose['head_x_index'] * 15.0 + ifacialmocap_pose['headBoneY'] = -emotion_pose['head_y_index'] * 10.0 + ifacialmocap_pose['headBoneZ'] = emotion_pose['neck_z_index'] * 15.0 # Update mouth values - ifacualmocap_pose['mouthSmileLeft'] = emotion_pose['mouth_aaa_index'] - ifacualmocap_pose['mouthSmileRight'] = emotion_pose['mouth_aaa_index'] - ifacualmocap_pose['mouthFrownLeft'] = emotion_pose['mouth_lowered_corner_left_index'] - ifacualmocap_pose['mouthFrownRight'] = emotion_pose['mouth_lowered_corner_right_index'] - ifacualmocap_pose['mouthPressLeft'] = emotion_pose['mouth_raised_corner_left_index'] - ifacualmocap_pose['mouthPressRight'] = emotion_pose['mouth_raised_corner_right_index'] + ifacialmocap_pose['mouthSmileLeft'] = emotion_pose['mouth_aaa_index'] + ifacialmocap_pose['mouthSmileRight'] = emotion_pose['mouth_aaa_index'] + ifacialmocap_pose['mouthFrownLeft'] = emotion_pose['mouth_lowered_corner_left_index'] + ifacialmocap_pose['mouthFrownRight'] = emotion_pose['mouth_lowered_corner_right_index'] + ifacialmocap_pose['mouthPressLeft'] = emotion_pose['mouth_raised_corner_left_index'] + ifacialmocap_pose['mouthPressRight'] = emotion_pose['mouth_raised_corner_right_index'] - return ifacualmocap_pose + return ifacialmocap_pose def update_blinking_pose(self, tranisitiondPose): PARTS = ['wink_left_index', 'wink_right_index'] @@ -702,11 +709,11 @@ class MainFrame(wx.Frame): return updated_list - def update_sway_pose_good(self, tranisitiondPose): + def update_sway_pose_good(self, tranisitiondPose): # TODO: good? why is there a bad one, too? keep only one! MOVEPARTS = ['head_y_index'] updated_list = [] - print( self.start_values, self.targets, self.progress, self.direction ) + print(self.start_values, self.targets, self.progress, self.direction) for item in tranisitiondPose: key, value = item.split(': ') @@ -725,11 +732,9 @@ class MainFrame(wx.Frame): self.targets[key] = current_value + random.uniform(-1, 1) self.progress[key] = 0 # Reset progress when setting a new target - # Use lerp to interpolate between start and target values + # Linearly interpolate between start and target values new_value = self.start_values[key] + self.progress[key] * (self.targets[key] - self.start_values[key]) - - # Ensure the value remains within bounds (just in case) - new_value = min(max(new_value, -1), 1) + new_value = min(max(new_value, -1), 1) # clip to bounds (just in case) # Update progress based on direction self.progress[key] += 0.02 * self.direction[key] @@ -744,7 +749,7 @@ class MainFrame(wx.Frame): MOVEPARTS = ['head_y_index'] updated_list = [] - #print( self.start_values, self.targets, self.progress, self.direction ) + # print( self.start_values, self.targets, self.progress, self.direction ) for item in tranisitiondPose: key, value = item.split(': ') @@ -752,11 +757,9 @@ class MainFrame(wx.Frame): if key in MOVEPARTS: current_value = float(value) - # Use lerp to interpolate between start and target values + # Linearly interpolate between start and target values new_value = self.start_values[key] + self.progress[key] * (self.targets[key] - self.start_values[key]) - - # Ensure the value remains within bounds (just in case) - new_value = min(max(new_value, -1), 1) + new_value = min(max(new_value, -1), 1) # clip to bounds (just in case) # Check if we've reached the target or start value is_close_to_target = abs(new_value - self.targets[key]) < 0.04 @@ -808,7 +811,7 @@ class MainFrame(wx.Frame): if key in transition_dict: # If the key is 'wink_left_index' or 'wink_right_index', set the value directly dont animate blinks - if key in ['wink_left_index', 'wink_right_index']: # BLINK FIX + if key in ['wink_left_index', 'wink_right_index']: # BLINK FIX last_value = transition_dict[key] # For all other keys, increment its value by 0.1 of the delta and clip it to the target @@ -820,6 +823,7 @@ class MainFrame(wx.Frame): updated_last_transition_pose.append(f"{key}: {last_value}") # If any value is less than the target, set inMotion to True + # TODO/FIXME: inMotion is not actually used; what was this supposed to do? if any(last_transition_dict[k] < transition_dict[k] for k in last_transition_dict if k in transition_dict): inMotion = True else: @@ -847,63 +851,64 @@ class MainFrame(wx.Frame): MainFrame.load_image(self, event=None, file_path=None) # call load_image function here return - #OLD METHOD - #ifacialmocap_pose = self.animationMain() #GET ANIMATION CHANGES - #current_posesaved = self.pose_converter.convert(ifacialmocap_pose) - #combined_posesaved = current_posesaved + # # OLD METHOD + # ifacialmocap_pose = self.animationMain() # GET ANIMATION CHANGES + # current_posesaved = self.pose_converter.convert(ifacialmocap_pose) + # combined_posesaved = current_posesaved - #NEW METHOD - #CREATES THE DEFAULT POSE AND STORES OBJ IN STRING - #ifacialmocap_pose = self.animationMain() #DISABLE FOR TESTING!!!!!!!!!!!!!!!!!!!!!!!! + # NEW METHOD + # CREATES THE DEFAULT POSE AND STORES OBJ IN STRING + # ifacialmocap_pose = self.animationMain() # DISABLE FOR TESTING!!!!!!!!!!!!!!!!!!!!!!!! ifacialmocap_pose = self.ifacialmocap_pose - #print("ifacialmocap_pose", ifacialmocap_pose) + # print("ifacialmocap_pose", ifacialmocap_pose) - #GET EMOTION SETTING + # GET EMOTION SETTING emotion_pose = self.get_emotion_values(emotion) - #print("emotion_pose ", emotion_pose) + # print("emotion_pose ", emotion_pose) - #MERGE EMOTION SETTING WITH CURRENT OUTPUT - updated_pose = self.update_ifacualmocap_pose(ifacialmocap_pose, emotion_pose) - #print("updated_pose ", updated_pose) + # MERGE EMOTION SETTING WITH CURRENT OUTPUT + # NOTE: This is a mutating method that overwrites the original `ifacialmocap_pose`. + updated_pose = self.update_ifacialmocap_pose(ifacialmocap_pose, emotion_pose) + # print("updated_pose ", updated_pose) - #CONVERT RESULT TO FORMAT NN CAN USE + # CONVERT RESULT TO FORMAT NN CAN USE current_pose = self.pose_converter.convert(updated_pose) - #print("current_pose ", current_pose) + # print("current_pose ", current_pose) - #SEND THROUGH CONVERT + # SEND THROUGH CONVERT current_pose = self.pose_converter.convert(ifacialmocap_pose) - #print("current_pose2 ", current_pose) + # print("current_pose2 ", current_pose) - #ADD LABELS/NAMES TO THE POSE + # ADD LABELS/NAMES TO THE POSE names_current_pose = MainFrame.addNamestoConvert(current_pose) - #print("current pose :", names_current_pose) + # print("current pose :", names_current_pose) - #GET THE EMOTION VALUES again for some reason + # GET THE EMOTION VALUES again for some reason emotion_pose2 = self.get_emotion_values(emotion) - #print("target pose :", emotion_pose2) + # print("target pose :", emotion_pose2) - #APPLY VALUES TO THE POSE AGAIN?? This needs to overwrite the values + # APPLY VALUES TO THE POSE AGAIN?? This needs to overwrite the values tranisitiondPose = self.animateToEmotion(names_current_pose, emotion_pose2) - #print("combine pose :", tranisitiondPose) + # print("combine pose :", tranisitiondPose) - #smooth animate - #print("LAST VALUES: ", lasttranisitiondPose) - #print("TARGER VALUES: ", tranisitiondPose) + # smooth animate + # print("LAST VALUES: ", lasttranisitiondPose) + # print("TARGER VALUES: ", tranisitiondPose) if lasttranisitiondPose != "NotInit": tranisitiondPose = self.update_transition_pose(lasttranisitiondPose, tranisitiondPose) - #print("smoothed: ", tranisitiondPose) + # print("smoothed: ", tranisitiondPose) - #Animate blinking + # Animate blinking tranisitiondPose = self.update_blinking_pose(tranisitiondPose) - #Animate Head Sway + # Animate Head Sway tranisitiondPose = self.update_sway_pose(tranisitiondPose) - #Animate Talking + # Animate Talking tranisitiondPose = self.update_talking_pose(tranisitiondPose) - #reformat the data correctly + # reformat the data correctly parsed_data = [] for item in tranisitiondPose: key, value_str = item.split(': ') @@ -911,7 +916,7 @@ class MainFrame(wx.Frame): parsed_data.append((key, value)) tranisitiondPosenew = [value for _, value in parsed_data] - #not sure what this is for TBH + # not sure what this is for TBH ifacialmocap_pose = tranisitiondPosenew if self.torch_source_image is None: @@ -921,7 +926,7 @@ class MainFrame(wx.Frame): del dc return - #pose = torch.tensor(tranisitiondPosenew, device=self.device, dtype=self.poser.get_dtype()) + # pose = torch.tensor(tranisitiondPosenew, device=self.device, dtype=self.poser.get_dtype()) pose = self.dict_to_tensor(tranisitiondPosenew).to(device=self.device, dtype=self.poser.get_dtype()) with torch.no_grad(): @@ -931,27 +936,25 @@ class MainFrame(wx.Frame): c, h, w = output_image.shape output_image = (255.0 * torch.transpose(output_image.reshape(c, h * w), 0, 1)).reshape(h, w, c).byte() - numpy_image = output_image.detach().cpu().numpy() wx_image = wx.ImageFromBuffer(numpy_image.shape[0], - numpy_image.shape[1], - numpy_image[:, :, 0:3].tobytes(), - numpy_image[:, :, 3].tobytes()) + numpy_image.shape[1], + numpy_image[:, :, 0:3].tobytes(), + numpy_image[:, :, 3].tobytes()) wx_bitmap = wx_image.ConvertToBitmap() dc = wx.MemoryDC() dc.SelectObject(self.result_image_bitmap) dc.Clear() dc.DrawBitmap(wx_bitmap, - (self.poser.get_image_size() - numpy_image.shape[0]) // 2, - (self.poser.get_image_size() - numpy_image.shape[1]) // 2, True) + (self.poser.get_image_size() - numpy_image.shape[0]) // 2, + (self.poser.get_image_size() - numpy_image.shape[1]) // 2, True) - numpy_image_bgra = numpy_image[:, :, [2, 1, 0, 3]] # Convert color channels from RGB to BGR and keep alpha channel + numpy_image_bgra = numpy_image[:, :, [2, 1, 0, 3]] # Convert color channels from RGB to BGR and keep alpha channel global_result_image = numpy_image_bgra del dc - time_now = time.time_ns() if self.last_update_time is not None: elapsed_time = time_now - self.last_update_time @@ -963,16 +966,15 @@ class MainFrame(wx.Frame): self.last_update_time = time_now - if(initAMI == True): #If the models are just now initalized stop animation to save + if initAMI: # If the models are just now initalized stop animation to save global_timer_paused = True initAMI = False if random.random() <= 0.01: trimmed_fps = round(fps, 1) - #print("talkinghead FPS: {:.1f}".format(trimmed_fps)) + print("talkinghead FPS: {:.1f}".format(trimmed_fps)) - - #Store current pose to use as last pose on next loop + # Store current pose to use as last pose on next loop lasttranisitiondPose = tranisitiondPose self.Refresh() @@ -1028,7 +1030,7 @@ class MainFrame(wx.Frame): except Exception as error: print("Error: ", error) - global_reload = None #reset the globe load + global_reload = None # Reset the globe load self.Refresh() if __name__ == "__main__": @@ -1041,7 +1043,10 @@ if __name__ == "__main__": choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'], help='The model to use.' ) - parser.add_argument('--char', type=str, required=False, help='The path to the character image.') + parser.add_argument('--char', + type=str, + required=False, + help='The filename of the character image under "tha3/images/".') parser.add_argument( '--device', type=str, @@ -1052,4 +1057,5 @@ if __name__ == "__main__": ) args = parser.parse_args() - launch_gui(device=args.device, model=args.model) + global_basedir = "" # in standalone mode, cwd is the "talkinghead" directory + launch_gui(device=args.device, model=args.model, char=args.char, standalone=True) diff --git a/talkinghead/tha3/app/manual_poser.py b/talkinghead/tha3/app/manual_poser.py index 86cae9a..f66b2f3 100644 --- a/talkinghead/tha3/app/manual_poser.py +++ b/talkinghead/tha3/app/manual_poser.py @@ -1,32 +1,269 @@ +"""THA3 manual poser. + +Pose an anime character manually, based on a suitable 512×512 static input image and some neural networks. + + +**What**: + +This app is an alternative to the live plugin mode of `talkinghead`. Given one static input image, +this allows the automatic generation of the 28 emotional expression sprites for your AI character, +for use with distilbert classification in SillyTavern. + +There are two motivations: + + - Much faster than inpainting all 28 expressions manually in Stable Diffusion. Enables agile experimentation + on the look of your character, since you only need to produce one new image to change the look. + - No CPU or GPU load while running SillyTavern, unlike the live plugin mode, which is cool, but slow. + +For best results for generating the static input image in Stable Diffusion, consider the various vtuber checkpoints +available on the internet. These should reduce the amount of work it takes to get SD to render your character in +a pose suitable for use as input. + +Results are often not perfect, but serviceable. + + +**How**: + +To run the manual poser, ensure that you have the correct wxPython installed in your "extras" conda venv, +open a terminal in the SillyTavern-extras top-level directory, and do the following: + + cd talkinghead + conda activate extras + ./start_manual_poser.sh + +Note that installing wxPython needs `libgtk-3-dev` (on Debian based distros), +so `sudo apt install libgtk-3-dev` before trying to `pip install wxPython`. +The install may take a very long time (even half an hour) as it needs to +compile a whole GUI toolkit. + + +**Who**: + +Original code written and neural networks designed and trained by Pramook Khungurn (@pkhungurn): + https://github.com/pkhungurn/talking-head-anime-3-demo + https://arxiv.org/abs/2311.17409 + +This fork maintained by the SillyTavern-extras project. + +Manual poser app improved and documented by Juha Jeronen (@Technologicat). +""" + +# TODO: manual poser: +# - Write new README: use case and supported features are different from the original THA3 package. +# - Refactor stuff needed both by this and the live mode that is served by `app.py`. + import argparse +import json import logging import os +import pathlib import sys -import PIL.Image -import numpy -import torch -import wx -import json -from typing import List +import time +from typing import Dict, List, Tuple -# Set the working directory to the "live2d" subdirectory to work with file structure -target_directory = os.path.join(os.getcwd(), "live2d") -os.chdir(target_directory) -sys.path.append(os.getcwd()) +import PIL.Image + +import numpy + +import torch + +import wx from tha3.poser.modes.load_poser import load_poser from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup -from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \ +from tha3.util import rgba_to_numpy_image, grid_change_to_numpy_image, \ rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image -current_directory = os.getcwd() -parent_directory = os.path.dirname(current_directory) -os.chdir(parent_directory) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Detect image file formats supported by the installed Pillow, and format a list for wxPython file open/save dialogs. +# TODO: This is not very useful unless we can filter these to get only formats that support an alpha channel. +# +# https://docs.wxpython.org/wx.FileDialog.html +# https://stackoverflow.com/questions/71112986/retrieve-a-list-of-supported-read-file-extensions-formats +# +# exts = PIL.Image.registered_extensions() +# PIL_supported_input_formats = {ex[1:].lower() for ex, f in exts.items() if f in PIL.Image.OPEN} # {".png", ".jpg", ...} -> {"png", "jpg", ...} +# PIL_supported_output_formats = {ex[1:].lower() for ex, f in exts.items() if f in PIL.Image.SAVE} +# def format_fileformat_list(supported_formats): +# return ["All files (*)|*"] + [f"{fmt.upper()} images (*.{fmt})|*.{fmt}" for fmt in sorted(supported_formats)] +# input_index_to_ext = [""] + sorted(PIL_supported_input_formats) # list index -> file extension +# input_ext_to_index = {ext: idx for idx, ext in enumerate(input_index_to_ext)} # file extension -> list index +# output_index_to_ext = [""] + sorted(PIL_supported_output_formats) +# output_ext_to_index = {ext: idx for idx, ext in enumerate(output_index_to_ext)} +# input_exts_and_descs_str = "|".join(format_fileformat_list(PIL_supported_input_formats)) # filter-spec accepted by `wx.FileDialog` +# output_exts_and_descs_str = "|".join(format_fileformat_list(PIL_supported_output_formats)) + +# The keys for a pose in the emotion JSON files. +# +# TODO: "eye_unimpressed" is arity-2, but has only one entry in the JSON. The current implementation smashes both into one, +# letting the second one (right slider) win. Maybe the two values should be saved separately, but we have to avoid +# breaking the live mode served by `app.py`. +posedict_keys = ["eyebrow_troubled_left_index", "eyebrow_troubled_right_index", + "eyebrow_angry_left_index", "eyebrow_angry_right_index", + "eyebrow_lowered_left_index", "eyebrow_lowered_right_index", + "eyebrow_raised_left_index", "eyebrow_raised_right_index", + "eyebrow_happy_left_index", "eyebrow_happy_right_index", + "eyebrow_serious_left_index", "eyebrow_serious_right_index", + "eye_wink_left_index", "eye_wink_right_index", + "eye_happy_wink_left_index", "eye_happy_wink_right_index", + "eye_surprised_left_index", "eye_surprised_right_index", + "eye_relaxed_left_index", "eye_relaxed_right_index", + "eye_unimpressed", "eye_unimpressed", + "eye_raised_lower_eyelid_left_index", "eye_raised_lower_eyelid_right_index", + "iris_small_left_index", "iris_small_right_index", + "mouth_aaa_index", + "mouth_iii_index", + "mouth_uuu_index", + "mouth_eee_index", + "mouth_ooo_index", + "mouth_delta", + "mouth_lowered_corner_left_index", "mouth_lowered_corner_right_index", + "mouth_raised_corner_left_index", "mouth_raised_corner_right_index", + "mouth_smirk", + "iris_rotation_x_index", "iris_rotation_y_index", + "head_x_index", "head_y_index", + "neck_z_index", + "body_y_index", "body_z_index", + "breathing_index"] +assert len(posedict_keys) == 45 + + +def load_emotion_presets() -> Tuple[Dict[str, Dict[str, float]], List[str]]: + """Load emotion presets from disk. + + These are JSON files in "talkinghead/emotions". + + Returns the tuple `(emotions, emotion_names)`, where:: + + emotions = {emotion0_name: posedict0, ...} + emotion_names = [emotion0_name, emotion1_name, ...] + + The dict contains the actual pose data. The list is a sorted list of emotion names + that can be used to map a linear index (e.g. the choice index in a GUI dropdown) + to the corresponding key of `emotions`. + + The directory "talkinghead/emotions" must also contain a "_defaults.json" file, + containing factory defaults (as a fallback) for the 28 standard emotions + (as recognized by distilbert), as well as a hidden "zero" preset that represents + a neutral pose. (This is separate from the "neutral" emotion, which is allowed + to be "non-zero".) + """ + emotion_names = [] + for root, dirs, files in os.walk("emotions", topdown=True): + for filename in files: + if filename == "_defaults.json": # skip the repository containing the default fallbacks + continue + if filename.lower().endswith(".json"): + emotion_names.append(filename[:-5]) # drop the ".json" + emotion_names.sort() # the 28 actual emotions + + # TODO: Note that currently, we build the list of emotion names from JSON filenames, + # and then check whether each JSON implements the emotion matching its filename. + # On second thought, I'm not sure whether that makes much sense. Maybe rethink the design. + # - We *do* want custom JSON files to show up in the list, if those are placed in "tha3/emotions". So the list of emotions shouldn't be hardcoded. + # - *Having* a fallback repository with factory defaults (and a hidden "zero" preset) is useful. + # But we are currently missing a way to reset an emotion to its factory default. + def load_emotion_with_fallback(emotion_name: str) -> Dict[str, float]: + try: + with open(os.path.join("emotions", f"{emotion_name}.json"), "r") as json_file: + emotions_from_json = json.load(json_file) # A single json file may contain presets for multiple emotions. + posedict = emotions_from_json[emotion_name] + except (FileNotFoundError, KeyError): # If no separate json exists for the specified emotion, load the default (all 28 emotions have a default). + with open(os.path.join("emotions", "_defaults.json"), "r") as json_file: + emotions_from_json = json.load(json_file) + posedict = emotions_from_json[emotion_name] + # If still not found, it's an error, so fail-fast: let the app exit with an informative exception message. + return posedict + + # Dict keeps its keys in insertion order, so define some special states before inserting the actual emotions. + emotions = {"[custom]": {}, # custom = the user has changed at least one value manually after last loading a preset + "[reset]": load_emotion_with_fallback("zero")} # reset = a preset with all sliders in their default positions. Found in "_defaults.json". + for emotion_name in emotion_names: + emotions[emotion_name] = load_emotion_with_fallback(emotion_name) + + emotion_names = list(emotions.keys()) + return emotions, emotion_names + + +class SimpleParamGroupsControlPanel(wx.Panel): + """A simple control panel for groups of arity-1 continuous parameters (i.e. float value, and no separate left/right controls). + + The panel represents a *category*, such as "body rotation". + + A category may have several *parameter groups*, all of which are active simultaneously. Here "parameter group" is a misnomer, + since in all use sites for this panel, each group has only one parameter. For example, "body rotation" has the groups ["body_y", "body_z"]. + """ + + def __init__(self, parent, + pose_param_category: PoseParameterCategory, + param_groups: List[PoseParameterGroup]): + super().__init__(parent, style=wx.SIMPLE_BORDER) + self.sizer = wx.BoxSizer(wx.VERTICAL) + self.SetSizer(self.sizer) + self.SetAutoLayout(1) + + self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] + for param_group in self.param_groups: + assert not param_group.is_discrete() + assert param_group.get_arity() == 1 + + self.sliders = [] + for param_group in self.param_groups: + title_text = wx.StaticText(self, label=param_group.get_group_name(), style=wx.ALIGN_CENTER) + title_text.SetFont(title_text.GetFont().Bold()) + self.sizer.Add(title_text, 0, wx.EXPAND) + # HACK: iris_rotation_*, head_*, body_* have range [-1, 1], but breathing has range [0, 1], + # and all of them should default to the *value* 0. + range = param_group.get_range() + min_value = int(range[0] * 1000) + max_value = int(range[1] * 1000) + slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL) + self.sizer.Add(slider, 0, wx.EXPAND) + self.sliders.append(slider) + + self.sizer.Fit(self) + + def write_to_pose(self, pose: List[float]) -> None: + """Update `pose` (in-place) by the current value(s) set in this control panel.""" + for param_group, slider in zip(self.param_groups, self.sliders): + alpha = (slider.GetValue() - slider.GetMin()) / (slider.GetMax() - slider.GetMin()) + param_index = param_group.get_parameter_index() + param_range = param_group.get_range() + pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha + + def read_from_pose(self, pose: List[float]) -> None: + """Overwrite the current value(s) in this control panel by those taken from `pose`.""" + for param_group, slider in zip(self.param_groups, self.sliders): + param_range = param_group.get_range() + param_index = param_group.get_parameter_index() + value = pose[param_index] # cherry-pick only relevant values from `pose` + alpha = (value - param_range[0]) / (param_range[1] - param_range[0]) + slider.SetValue(int(slider.GetMin() + alpha * (slider.GetMax() - slider.GetMin()))) + class MorphCategoryControlPanel(wx.Panel): + """A more complex control panel with grouping semantics. + + The panel represents a *category*, such as "eyebrow". + + A category may have several *parameter groups*, only one of which can be active at any given time. + + For example, the "eyebrow" category has the parameter groups ["eyebrow_troubled", "eyebrow_angry", ...]. + + Each parameter group can be: + - Continuous with arity 1 (one slider), + - Continuous with arity 2 (two sliders, for separate left/right control), or + - Discrete (on/off). + + The panel allows the user to select a parameter group within the category, and enables/disables its + UI controls appropriately. The user can then use the controls to set the values for the selected + parameter group within the category represented by the panel. + """ def __init__(self, parent, - title: str, + category_title: str, pose_param_category: PoseParameterCategory, param_groups: List[PoseParameterGroup]): super().__init__(parent, style=wx.SIMPLE_BORDER) @@ -35,11 +272,13 @@ class MorphCategoryControlPanel(wx.Panel): self.SetSizer(self.sizer) self.SetAutoLayout(1) - title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER) - self.sizer.Add(title_text, 0, wx.EXPAND) + self.title_text = wx.StaticText(self, label=category_title, style=wx.ALIGN_CENTER) + self.title_text.SetFont(self.title_text.GetFont().Bold()) + self.sizer.Add(self.title_text, 0, wx.EXPAND) self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] - self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups]) + self.param_group_names = [group.get_group_name() for group in self.param_groups] + self.choice = wx.Choice(self, choices=self.param_group_names) if len(self.param_groups) > 0: self.choice.SetSelection(0) self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated) @@ -59,7 +298,8 @@ class MorphCategoryControlPanel(wx.Panel): self.sizer.Fit(self) - def update_ui(self): + def update_ui(self) -> None: + """Enable/disable UI controls based on the currently active parameter group.""" param_group = self.param_groups[self.choice.GetSelection()] if param_group.is_discrete(): self.left_slider.Enable(False) @@ -74,13 +314,26 @@ class MorphCategoryControlPanel(wx.Panel): self.right_slider.Enable(True) self.checkbox.Enable(False) - def on_choice_updated(self, event: wx.Event): + def on_choice_updated(self, event: wx.Event) -> None: + """Automatically optimize usability for the new arity and discrete/continuous state.""" param_group = self.param_groups[self.choice.GetSelection()] if param_group.is_discrete(): - self.checkbox.SetValue(True) + self.checkbox.SetValue(True) # discrete parameter group: set to "on" when switched into + self.left_slider.SetValue(self.left_slider.GetMin()) + self.right_slider.SetValue(self.right_slider.GetMin()) + else: + if param_group.get_arity() == 2: # make it apparent that both sliders are in use now + self.right_slider.SetValue(self.left_slider.GetValue()) # ...by copying value left->right + else: # arity 1, right slider not in use, so zero it out visually. + self.right_slider.SetValue(self.right_slider.GetMin()) self.update_ui() + event.Skip() # allow other handlers for the same event to run - def set_param_value(self, pose: List[float]): + def write_to_pose(self, pose: List[float]) -> None: + """Update `pose` (in-place) by the current value(s) set in this control panel. + + Only the currently chosen parameter group is applied. + """ if len(self.param_groups) == 0: return selected_morph_index = self.choice.GetSelection() @@ -92,52 +345,63 @@ class MorphCategoryControlPanel(wx.Panel): pose[param_index + i] = 1.0 else: param_range = param_group.get_range() - alpha = (self.left_slider.GetValue() + 1000) / 2000.0 + alpha = (self.left_slider.GetValue() - self.left_slider.GetMin()) * 1.0 / (self.left_slider.GetMax() - self.left_slider.GetMin()) # -> [0, 1] pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha if param_group.get_arity() == 2: - alpha = (self.right_slider.GetValue() + 1000) / 2000.0 + alpha = (self.right_slider.GetValue() - self.right_slider.GetMin()) * 1.0 / (self.right_slider.GetMax() - self.right_slider.GetMin()) pose[param_index + 1] = param_range[0] + (param_range[1] - param_range[0]) * alpha + def read_from_pose(self, pose: List[float]) -> None: + """Overwrite the current value(s) in this control panel by those taken from `pose`. -class SimpleParamGroupsControlPanel(wx.Panel): - def __init__(self, parent, - pose_param_category: PoseParameterCategory, - param_groups: List[PoseParameterGroup]): - super().__init__(parent, style=wx.SIMPLE_BORDER) - self.sizer = wx.BoxSizer(wx.VERTICAL) - self.SetSizer(self.sizer) - self.SetAutoLayout(1) + All parameter groups in this panel are scanned to find a nonzero value in `pose`. + The parameter group that first finds a nonzero value wins, selects its morph for this panel, + and applies the values to the sliders in the panel. - self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category] - for param_group in self.param_groups: - assert not param_group.is_discrete() - assert param_group.get_arity() == 1 - - self.sliders = [] - for param_group in self.param_groups: - static_text = wx.StaticText( - self, - label=" ------------ %s ------------ " % param_group.get_group_name(), style=wx.ALIGN_CENTER) - self.sizer.Add(static_text, 0, wx.EXPAND) - range = param_group.get_range() - min_value = int(range[0] * 1000) - max_value = int(range[1] * 1000) - slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL) - self.sizer.Add(slider, 0, wx.EXPAND) - self.sliders.append(slider) - - self.sizer.Fit(self) - - def set_param_value(self, pose: List[float]): - if len(self.param_groups) == 0: - return - for param_group_index in range(len(self.param_groups)): - param_group = self.param_groups[param_group_index] - slider = self.sliders[param_group_index] - param_range = param_group.get_range() + If nothing matches, the first available morph is selected, and the sliders are set to zero. + """ + # Find which morph (param group) is active in our category in `pose`. + for morph_index, param_group in enumerate(self.param_groups): param_index = param_group.get_parameter_index() - alpha = (slider.GetValue() - slider.GetMin()) * 1.0 / (slider.GetMax() - slider.GetMin()) - pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha + value = pose[param_index] + if value != 0.0: + break + # An arity-2 param group is active also when just the right slider is nonzero. + if param_group.get_arity() == 2: + value = pose[param_index + 1] + if value != 0.0: + break + else: # No param group in this panel's category had a nonzero value in `pose`. + if len(self.param_groups) > 0: + logger.debug(f"category {self.title_text.GetLabel()}: no nonzero values, chose default morph {self.param_group_names[0]}") + self.choice.SetSelection(0) # choose the first param group + self.left_slider.SetValue(self.left_slider.GetMin()) + self.right_slider.SetValue(self.right_slider.GetMin()) + self.checkbox.SetValue(False) + self.update_ui() + return + logger.debug(f"category {self.title_text.GetLabel()}: found nonzero values, chose morph {self.param_group_names[morph_index]}") + self.choice.SetSelection(morph_index) + if param_group.is_discrete(): + self.left_slider.SetValue(self.left_slider.GetMin()) + self.right_slider.SetValue(self.right_slider.GetMin()) + if pose[param_index]: + self.checkbox.SetValue(True) + else: + self.checkbox.SetValue(False) + else: + self.checkbox.SetValue(False) + param_range = param_group.get_range() + value = pose[param_index] + alpha = (value - param_range[0]) / (param_range[1] - param_range[0]) + self.left_slider.SetValue(int(self.left_slider.GetMin() + alpha * (self.left_slider.GetMax() - self.left_slider.GetMin()))) + if param_group.get_arity() == 2: + value = pose[param_index + 1] + alpha = (value - param_range[0]) / (param_range[1] - param_range[0]) + self.right_slider.SetValue(int(self.right_slider.GetMin() + alpha * (self.right_slider.GetMax() - self.right_slider.GetMin()))) + else: # arity 1, right slider not in use, so zero it out visually. + self.right_slider.SetValue(self.right_slider.GetMin()) + self.update_ui() def convert_output_image_from_torch_to_numpy(output_image): @@ -155,14 +419,64 @@ def convert_output_image_from_torch_to_numpy(output_image): elif output_image.shape[0] == 2: numpy_image = grid_change_to_numpy_image(output_image, num_channels=4) else: - raise RuntimeError("Unsupported # image channels: %d" % output_image.shape[0]) + raise RuntimeError(f"Unsupported # image channels: {output_image.shape[0]}") numpy_image = numpy.uint8(numpy.rint(numpy_image * 255.0)) return numpy_image +class FpsStatistics: + def __init__(self): + self.count = 100 + self.fps = [] + + def add_fps(self, fps: float) -> None: + self.fps.append(fps) + while len(self.fps) > self.count: + del self.fps[0] + + def get_average_fps(self) -> float: + if len(self.fps) == 0: + return 0.0 + else: + return sum(self.fps) / len(self.fps) + + +class MyFileDropTarget(wx.FileDropTarget): + def OnDropFiles(self, x, y, filenames): + if len(filenames) > 1: + return False + filename = filenames[0] + if filename.lower().endswith(".png"): + logger.info(f"Accepting drop for {filename}") + main_frame.load_image(filename) + return True + elif filename.lower().endswith(".json"): + logger.info(f"Accepting drop for {filename}") + main_frame.load_json(filename) + return True + logger.info(f"Rejecting drop for {filename}, unsupported file type") + return False + + class MainFrame(wx.Frame): - def __init__(self, poser: Poser, device: torch.device): - super().__init__(None, wx.ID_ANY, "Poser") + """Main app window for THA3 Manual Poser. + + Usage, roughly:: + + from tha3.poser.modes.load_poser import load_poser + + model = "separable_float" # or some other directory containing a model, under "tha3/models" + device = torch.device("cuda") # or "cpu", but then will be slow + poser = load_poser(model, device, modelsdir="tha3/models") + + app = wx.App() + main_frame = MainFrame(poser, device, model) + main_frame.Show(True) + main_frame.timer.Start(30) + app.MainLoop() + """ + def __init__(self, poser: Poser, device: torch.device, model: str): + super().__init__(None, wx.ID_ANY, f"THA3 Manual Poser [{device}] [{model}]") self.poser = poser self.dtype = self.poser.get_dtype() self.device = device @@ -179,17 +493,48 @@ class MainFrame(wx.Frame): self.init_right_panel() self.main_sizer.Fit(self) + self.fps_statistics = FpsStatistics() + self.timer = wx.Timer(self, wx.ID_ANY) self.Bind(wx.EVT_TIMER, self.update_images, self.timer) + load_image_id = wx.NewIdRef() + load_json_id = wx.NewIdRef() save_image_id = wx.NewIdRef() + save_batch_id = wx.NewIdRef() + focus_preset_id = wx.NewIdRef() + focus_editor_id = wx.NewIdRef() + focus_outputindex_id = wx.NewIdRef() + def focus_presets(event: wx.Event) -> None: + self.emotion_choice.SetFocus() + # TODO: Add hotkeys for each morph control group, and for the non-morph control groups. + def focus_editor(event: wx.Event) -> None: + if not self.morph_control_panels: + return + first_morph_control_panel = list(self.morph_control_panels.values())[0] + first_morph_control_panel.choice.SetFocus() + def focus_output_index(event: wx.Event) -> None: + self.output_index_choice.SetFocus() + self.Bind(wx.EVT_MENU, self.on_load_image, id=load_image_id) + self.Bind(wx.EVT_MENU, self.on_load_json, id=load_json_id) self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id) + self.Bind(wx.EVT_MENU, self.on_save_all_emotions, id=save_batch_id) + self.Bind(wx.EVT_MENU, focus_presets, id=focus_preset_id) + self.Bind(wx.EVT_MENU, focus_editor, id=focus_editor_id) + self.Bind(wx.EVT_MENU, focus_output_index, id=focus_outputindex_id) accelerator_table = wx.AcceleratorTable([ - (wx.ACCEL_CTRL, ord('S'), save_image_id) + (wx.ACCEL_CTRL, ord("O"), load_image_id), + (wx.ACCEL_CTRL | wx.ACCEL_SHIFT, ord("O"), load_json_id), + (wx.ACCEL_CTRL, ord("S"), save_image_id), + (wx.ACCEL_CTRL | wx.ACCEL_SHIFT, ord("S"), save_batch_id), + (wx.ACCEL_CTRL, ord("P"), focus_preset_id), + (wx.ACCEL_CTRL, ord("E"), focus_editor_id), + (wx.ACCEL_CTRL, ord("I"), focus_outputindex_id) ]) self.SetAcceleratorTable(accelerator_table) self.last_pose = None + self.last_emotion_index = None self.last_output_index = self.output_index_choice.GetSelection() self.last_output_numpy_image = None @@ -198,47 +543,81 @@ class MainFrame(wx.Frame): self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size) self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size) self.source_image_dirty = True + self.update_in_progress = False - def init_left_panel(self): + def on_erase_background(self, event: wx.Event) -> None: + pass + + def on_pose_edited(self, event: wx.Event) -> None: + """Automatically choose the '[custom]' emotion preset (to indicate edited state) when the pose is manually edited.""" + self.emotion_choice.SetSelection(0) + self.last_emotion_index = 0 + event.Skip() # allow other handlers for the same event to run + + def init_left_panel(self) -> None: + """Initialize the input image and emotion preset panel.""" self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1)) self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) - left_panel_sizer = wx.BoxSizer(wx.VERTICAL) - self.left_panel.SetSizer(left_panel_sizer) + self.left_panel_sizer = wx.BoxSizer(wx.VERTICAL) + self.left_panel.SetSizer(self.left_panel_sizer) self.left_panel.SetAutoLayout(1) self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size), style=wx.SIMPLE_BORDER) self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel) self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background) - left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) + self.file_drop_target = MyFileDropTarget() + self.source_image_panel.SetDropTarget(self.file_drop_target) + self.left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE) - self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n") - left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND) - self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image) + # Emotion picker. + self.emotions, self.emotion_names = load_emotion_presets() - left_panel_sizer.Fit(self.left_panel) + # # Horizontal emotion picker layout; looks bad, text label vertical alignment is wrong. + # self.emotion_panel = wx.Panel(self.left_panel, style=wx.SIMPLE_BORDER, size=(-1, -1)) + # self.emotion_panel_sizer = wx.BoxSizer(wx.HORIZONTAL) + # self.emotion_panel.SetSizer(self.emotion_panel_sizer) + # self.emotion_panel.SetAutoLayout(1) + # self.emotion_panel_sizer.Add(wx.StaticText(self.emotion_panel, label="Emotion presets", style=wx.ALIGN_CENTRE_HORIZONTAL)) + # self.emotion_choice = wx.Choice(self.emotion_panel, choices=self.emotion_names) + # self.emotion_choice.SetSelection(0) + # self.emotion_panel_sizer.Add(self.emotion_choice, 0, wx.EXPAND) + # left_panel_sizer.Add(self.emotion_panel, 0, wx.EXPAND) + + # Vertical emotion picker layout. + self.left_panel_sizer.Add(wx.StaticText(self.left_panel, label="Emotion preset [Ctrl+P]", style=wx.ALIGN_LEFT)) + self.emotion_choice = wx.Choice(self.left_panel, choices=self.emotion_names) + self.emotion_choice.SetSelection(0) + self.left_panel_sizer.Add(self.emotion_choice, 0, wx.EXPAND) + + self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad image [Ctrl+O]\n\n") + self.left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND) + self.load_image_button.Bind(wx.EVT_BUTTON, self.on_load_image) + + self.load_json_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad JSON [Ctrl+Shift+O]\n\n") + self.left_panel_sizer.Add(self.load_json_button, 1, wx.EXPAND) + self.load_json_button.Bind(wx.EVT_BUTTON, self.on_load_json) + + self.left_panel_sizer.Fit(self.left_panel) self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE) - def on_erase_background(self, event: wx.Event): - pass - - def init_control_panel(self): + def init_control_panel(self) -> None: + """Initialize the pose editor panel.""" self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.control_panel.SetSizer(self.control_panel_sizer) self.control_panel.SetMinSize(wx.Size(256, 1)) - morph_categories = [ - PoseParameterCategory.EYEBROW, - PoseParameterCategory.EYE, - PoseParameterCategory.MOUTH, - PoseParameterCategory.IRIS_MORPH - ] - morph_category_titles = { - PoseParameterCategory.EYEBROW: " ------------ Eyebrow ------------ ", - PoseParameterCategory.EYE: " ------------ Eye ------------ ", - PoseParameterCategory.MOUTH: " ------------ Mouth ------------ ", - PoseParameterCategory.IRIS_MORPH: " ------------ Iris morphs ------------ ", - } + self.control_panel_sizer.Add(wx.StaticText(self.control_panel, label="Editor [Ctrl+E]", style=wx.ALIGN_CENTER), + wx.SizerFlags().Expand()) + + morph_categories = [PoseParameterCategory.EYEBROW, + PoseParameterCategory.EYE, + PoseParameterCategory.MOUTH, + PoseParameterCategory.IRIS_MORPH] + morph_category_titles = {PoseParameterCategory.EYEBROW: "Eyebrow", + PoseParameterCategory.EYE: "Eye", + PoseParameterCategory.MOUTH: "Mouth", + PoseParameterCategory.IRIS_MORPH: "Iris"} self.morph_control_panels = {} for category in morph_categories: param_groups = self.poser.get_pose_parameter_groups() @@ -250,32 +629,42 @@ class MainFrame(wx.Frame): morph_category_titles[category], category, self.poser.get_pose_parameter_groups()) + # Trigger the choice of the "[custom]" emotion preset when the pose is edited in this panel. + control_panel.choice.Bind(wx.EVT_CHOICE, self.on_pose_edited) + control_panel.left_slider.Bind(wx.EVT_SLIDER, self.on_pose_edited) + control_panel.right_slider.Bind(wx.EVT_SLIDER, self.on_pose_edited) + control_panel.checkbox.Bind(wx.EVT_CHECKBOX, self.on_pose_edited) self.morph_control_panels[category] = control_panel self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND) self.non_morph_control_panels = {} - non_morph_categories = [ - PoseParameterCategory.IRIS_ROTATION, - PoseParameterCategory.FACE_ROTATION, - PoseParameterCategory.BODY_ROTATION, - PoseParameterCategory.BREATHING - ] + non_morph_categories = [PoseParameterCategory.IRIS_ROTATION, + PoseParameterCategory.FACE_ROTATION, + PoseParameterCategory.BODY_ROTATION, + PoseParameterCategory.BREATHING] for category in non_morph_categories: param_groups = self.poser.get_pose_parameter_groups() filtered_param_groups = [group for group in param_groups if group.get_category() == category] if len(filtered_param_groups) == 0: continue - control_panel = SimpleParamGroupsControlPanel( - self.control_panel, - category, - self.poser.get_pose_parameter_groups()) + control_panel = SimpleParamGroupsControlPanel(self.control_panel, + category, + self.poser.get_pose_parameter_groups()) + # Trigger the choice of the "[custom]" emotion preset when the pose is edited in this panel. + for slider in control_panel.sliders: + slider.Bind(wx.EVT_SLIDER, self.on_pose_edited) self.non_morph_control_panels[category] = control_panel self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND) + self.fps_text = wx.StaticText(self.control_panel, label="FPS counter will appear here") + self.fps_text.SetForegroundColour((0, 255, 0)) + self.control_panel_sizer.Add(self.fps_text, wx.SizerFlags().Border()) + self.control_panel_sizer.Fit(self.control_panel) self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE) - def init_right_panel(self): + def init_right_panel(self) -> None: + """Initialize the output image and output controls panel.""" self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER) right_panel_sizer = wx.BoxSizer(wx.VERTICAL) self.right_panel.SetSizer(right_panel_sizer) @@ -291,16 +680,22 @@ class MainFrame(wx.Frame): choices=[str(i) for i in range(self.poser.get_output_length())]) self.output_index_choice.SetSelection(0) right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE) + right_panel_sizer.Add(wx.StaticText(self.right_panel, label="Output index [Ctrl+I] [meaning depends on the model]", style=wx.ALIGN_LEFT)) right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND) - self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n") + self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave image and JSON [Ctrl+S]\n\n") right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND) self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image) + self.save_all_emotions_button = wx.Button(self.right_panel, wx.ID_ANY, "\nBatch save image and JSON from all presets [Ctrl+Shift+S]\n\n") + right_panel_sizer.Add(self.save_all_emotions_button, 1, wx.EXPAND) + self.save_all_emotions_button.Bind(wx.EVT_BUTTON, self.on_save_all_emotions) + right_panel_sizer.Fit(self.right_panel) self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE) - def create_param_category_choice(self, param_category: PoseParameterCategory): + def create_param_category_choice(self, param_category: PoseParameterCategory) -> wx.Choice: + """Create a `wx.Choice` dropdown for the given pose parameter category (eyebrow, eye, ...).""" params = [] for param_group in self.poser.get_pose_parameter_groups(): if param_group.get_category() == param_category: @@ -310,191 +705,452 @@ class MainFrame(wx.Frame): choice.SetSelection(0) return choice - def load_image(self, event: wx.Event): - dir_name = "data/images" - file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN) - if file_dialog.ShowModal() == wx.ID_OK: - image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) + def on_load_image(self, event: wx.Event) -> None: + """Ask the user for and load an input image.""" + dir_name = "tha3/images" # This is where `example.png` is. + file_dialog = wx.FileDialog(self, "Load input image", dir_name, "", "PNG files (*.png)|*.png", wx.FD_OPEN) + try: + if file_dialog.ShowModal() == wx.ID_OK: + image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) + self.load_image(image_file_name) + finally: + file_dialog.Destroy() + + def on_load_json(self, event: wx.Event) -> None: + """Ask the user for and load a custom emotion JSON file.""" + dir_name = "output" # This is where "Save image and JSON" puts them by default, so... + file_dialog = wx.FileDialog(self, "Load JSON", dir_name, "", "JSON files (*.json)|*.json", wx.FD_OPEN) + try: + if file_dialog.ShowModal() == wx.ID_OK: + json_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) + self.load_json(json_file_name) + finally: + file_dialog.Destroy() + + def load_image(self, image_file_name: str) -> None: + """Load an input image.""" + try: + pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name), + (self.poser.get_image_size(), self.poser.get_image_size())) + w, h = pil_image.size + if pil_image.mode != "RGBA": # input image must have an alpha channel + self.wx_source_image = None + self.torch_source_image = None + logger.warning(f"Incompatible input image (no alpha channel), canceling load: {image_file_name}") + else: + logger.info(f"Loaded input image: {image_file_name}") + self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes()) + self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image)\ + .to(self.device).to(self.dtype) + self.source_image_dirty = True + self.Refresh() + self.Update() + except Exception as exc: + logger.error(f"Could not load image {image_file_name}, reason: {exc}") + message_dialog = wx.MessageDialog(self, f"Could not load image {image_file_name}, reason: {exc}", "THA3 Manual Poser", wx.OK) try: - pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name), - (self.poser.get_image_size(), self.poser.get_image_size())) - w, h = pil_image.size - if pil_image.mode != 'RGBA': - self.source_image_string = "Image must have alpha channel!" - self.wx_source_image = None - self.torch_source_image = None - else: - self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes()) - self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image)\ - .to(self.device).to(self.dtype) - self.source_image_dirty = True + message_dialog.ShowModal() + finally: + message_dialog.Destroy() + + def load_json(self, json_file_name: str) -> None: + """Load a custom emotion JSON file.""" + try: + # Load the emotion JSON file + with open(json_file_name, "r") as json_file: + emotions_from_json = json.load(json_file) + # TODO: Here we just take the first emotion from the file. + if not emotions_from_json: + logger.warning(f"No emotions defined in given JSON file, canceling load: {json_file_name}") + return + first_emotion_name = list(emotions_from_json.keys())[0] # first in insertion order, i.e. topmost in file + if len(emotions_from_json) > 1: + logger.warning(f"File {json_file_name} contains multiple emotions, loading the first one '{first_emotion_name}'.") + posedict = emotions_from_json[first_emotion_name] + pose = self.posedict_to_pose(posedict) + + # Apply loaded emotion + self.set_current_pose(pose) + + # Auto-select "[custom]" + self.emotion_choice.SetSelection(0) + + # Do the GUI update after any pending events have processed + def on_load_json_cont(): self.Refresh() self.Update() - except: - message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK) + wx.CallAfter(on_load_json_cont) + except Exception as exc: + logger.error(f"Could not load JSON {json_file_name}, reason: {exc}") + message_dialog = wx.MessageDialog(self, f"Could not load JSON {json_file_name}, reason: {exc}", "THA3 Manual Poser", wx.OK) + try: message_dialog.ShowModal() + finally: message_dialog.Destroy() - file_dialog.Destroy() + else: + logger.info(f"Loaded JSON {json_file_name}") - def paint_source_image_panel(self, event: wx.Event): + def paint_source_image_panel(self, event: wx.Event) -> None: wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap) - def paint_result_image_panel(self, event: wx.Event): + def paint_result_image_panel(self, event: wx.Event) -> None: wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap) - def draw_nothing_yet_string_to_bitmap(self, bitmap): + def draw_message_to_bitmap(self, bitmap: wx.Bitmap, message: str) -> None: + """Write (in-place) a placeholder one-line message into a given bitmap. Used when no image is loaded yet.""" dc = wx.MemoryDC() dc.SelectObject(bitmap) dc.Clear() font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS)) dc.SetFont(font) - w, h = dc.GetTextExtent("Nothing yet!") - dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - - h) // 2) + w, h = dc.GetTextExtent(message) + dc.DrawText(message, (self.image_size - w) // 2, (self.image_size - - h) // 2) del dc - def get_current_pose(self): + def get_current_pose(self) -> List[float]: + """Get the current pose of the character as a list of morph values (in the order the models expect them). + + We do this by reading the values from the UI elements in the control panel. + """ current_pose = [0.0 for i in range(self.poser.get_num_parameters())] for morph_control_panel in self.morph_control_panels.values(): - morph_control_panel.set_param_value(current_pose) + morph_control_panel.write_to_pose(current_pose) for rotation_control_panel in self.non_morph_control_panels.values(): - rotation_control_panel.set_param_value(current_pose) + rotation_control_panel.write_to_pose(current_pose) return current_pose - def update_images(self, event: wx.Event): - current_pose = self.get_current_pose() - if not self.source_image_dirty \ - and self.last_pose is not None \ - and self.last_pose == current_pose \ - and self.last_output_index == self.output_index_choice.GetSelection(): + def set_current_pose(self, pose: List[float]) -> None: + """Write `pose` to the UI controls in the editor panel. + + Note that after this, you have to flush the wx event queue for the GUI to update itself correctly. + So if you call `set_current_pose` and intend to do something immediately, instead do that something + using `wx.CallAfter`. + """ + # `update_images` calls us; but if it is not already running (i.e. if we are called by something else), + # we should not let it run until the pose update is complete. + old_update_in_progress = self.update_in_progress + self.update_in_progress = True + try: + for panel in self.morph_control_panels.values(): + panel.read_from_pose(pose) + for panel in self.non_morph_control_panels.values(): + panel.read_from_pose(pose) + finally: + self.update_in_progress = old_update_in_progress + + def update_images(self, event: wx.Event) -> None: # This runs on a timer; keep the code as light as reasonably possible. + """Update the input and output images. + + The output image is rendered when necessary. + """ + # Though we're running in a single thread, the `wx.CallAfter` makes this concurrent, + # so the contents of this function should really be in a critical section. + # + # TODO: Atomic locking/mutex. + if self.update_in_progress: return - self.last_pose = current_pose - self.last_output_index = self.output_index_choice.GetSelection() + self.update_in_progress = True + last_update_time = time.time_ns() + actually_rendered = False # For the FPS counter, to detect if a render actually took place. - if self.torch_source_image is None: - self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap) - self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap) - self.source_image_dirty = False - self.Refresh() - self.Update() - return + # Apply the currently selected emotion, unless "[custom]" is selected, in which case skip this. + # Note this may modify the current pose, hence we do this first. + current_emotion_index = self.emotion_choice.GetSelection() + if current_emotion_index != 0 and current_emotion_index != self.last_emotion_index: # not "[custom]" + self.last_emotion_index = current_emotion_index + emotion_name = self.emotion_choice.GetString(current_emotion_index) + logger.info(f"Loading emotion preset {emotion_name}") + posedict = self.emotions[emotion_name] + pose = self.posedict_to_pose(posedict) + self.set_current_pose(pose) + current_pose = pose + else: + current_pose = self.get_current_pose() - if self.source_image_dirty: - dc = wx.MemoryDC() - dc.SelectObject(self.source_image_bitmap) - dc.Clear() - dc.DrawBitmap(self.wx_source_image, 0, 0) - self.source_image_dirty = False + # `wx.Slider.SetValue` needs to handle some events to update the visible thumb position, + # so we must defer the rest of our processing until currently pending events have been processed. + # + # https://forums.wxwidgets.org/viewtopic.php?t=47723 + # + # This code looks like JavaScript apps did before promises became a thing, essentially + # for the same reason. Manually spelling out async continuations is so 1990s, but: + # + # - These classical GUI toolkits were invented before the async/await syntax, so meh. + # - In a Lisp, we'd phrase this as something like `(wx-call-after-with (lambda: ...))` + # to have a clearer presentation order (we want to "call now the following thing...", + # not "here's a lengthy thing and by the way, call it now"), but Python doesn't have + # a proper lambda, so meh. + # + # Just keep in mind this "function" (technically, closure) is just a block of code + # to be run slightly later. + def update_images_cont() -> None: + try: + if not self.source_image_dirty \ + and self.last_pose is not None \ + and self.last_pose == current_pose \ + and self.last_output_index == self.output_index_choice.GetSelection(): + return + self.last_pose = current_pose + self.last_output_index = self.output_index_choice.GetSelection() - pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype) - output_index = self.output_index_choice.GetSelection() - with torch.no_grad(): - output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu() + if self.torch_source_image is None: + self.draw_message_to_bitmap(self.source_image_bitmap, "[No image loaded]") + self.draw_message_to_bitmap(self.result_image_bitmap, "[No image loaded]") + self.source_image_dirty = False + return - numpy_image = convert_output_image_from_torch_to_numpy(output_image) - self.last_output_numpy_image = numpy_image - wx_image = wx.ImageFromBuffer( - numpy_image.shape[0], - numpy_image.shape[1], - numpy_image[:, :, 0:3].tobytes(), - numpy_image[:, :, 3].tobytes()) - wx_bitmap = wx_image.ConvertToBitmap() + if self.source_image_dirty: + dc = wx.MemoryDC() + dc.SelectObject(self.source_image_bitmap) + dc.Clear() + dc.DrawBitmap(self.wx_source_image, 0, 0) + self.source_image_dirty = False - dc = wx.MemoryDC() - dc.SelectObject(self.result_image_bitmap) - dc.Clear() - dc.DrawBitmap(wx_bitmap, - (self.image_size - numpy_image.shape[0]) // 2, - (self.image_size - numpy_image.shape[1]) // 2, - True) - del dc + pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype) + output_index = self.output_index_choice.GetSelection() + with torch.no_grad(): + output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu() - self.Refresh() - self.Update() - - def get_current_posedict(self): - # Your dictionary of keys - keys = ['eyebrow_troubled_left_index', 'eyebrow_troubled_right_index', 'eyebrow_angry_left_index', 'eyebrow_angry_right_index', 'eyebrow_lowered_left_index', 'eyebrow_lowered_right_index', 'eyebrow_raised_left_index', 'eyebrow_raised_right_index', 'eyebrow_happy_left_index', 'eyebrow_happy_right_index', 'eyebrow_serious_left_index', 'eyebrow_serious_right_index', 'eye_wink_left_index', 'eye_wink_right_index', 'eye_happy_wink_left_index', 'eye_happy_wink_right_index', 'eye_surprised_left_index', 'eye_surprised_right_index', 'eye_relaxed_left_index', 'eye_relaxed_right_index', 'eye_unimpressed', 'eye_unimpressed', 'eye_raised_lower_eyelid_left_index', 'eye_raised_lower_eyelid_right_index', 'iris_small_left_index', 'iris_small_right_index', 'mouth_aaa_index', 'mouth_iii_index', 'mouth_uuu_index', 'mouth_eee_index', 'mouth_ooo_index', 'mouth_delta', 'mouth_lowered_corner_left_index', 'mouth_lowered_corner_right_index', 'mouth_raised_corner_left_index', 'mouth_raised_corner_right_index', 'mouth_smirk', 'iris_rotation_x_index', 'iris_rotation_y_index', 'head_x_index', 'head_y_index', 'neck_z_index', 'body_y_index', 'body_z_index', 'breathing_index'] - - # Get the current pose as a list of values - current_pose_values = self.get_current_pose() # replace this with the actual method or property that gets the pose values + numpy_image = convert_output_image_from_torch_to_numpy(output_image) + self.last_output_numpy_image = numpy_image + wx_image = wx.ImageFromBuffer( + numpy_image.shape[0], + numpy_image.shape[1], + numpy_image[:, :, 0:3].tobytes(), + numpy_image[:, :, 3].tobytes()) + wx_bitmap = wx_image.ConvertToBitmap() - # Create a dictionary by zipping together the keys and values - current_pose_dict = dict(zip(keys, current_pose_values)) + dc = wx.MemoryDC() + dc.SelectObject(self.result_image_bitmap) + dc.Clear() + dc.DrawBitmap(wx_bitmap, + (self.image_size - numpy_image.shape[0]) // 2, + (self.image_size - numpy_image.shape[1]) // 2, + True) + del dc + nonlocal actually_rendered + actually_rendered = True + finally: + # Set up another async continuation to finish things up. + # + # I have no idea why the final forced Refresh/Update must wait until other pending + # GUI events have been processed. When `update_images_cont` *starts*, the sliders + # should have been set to their final positions, and those events processed already. + # + # But for whatever reason, this fixes the remaining flakiness with the GUI element + # not visually updating when using `slider.SetValue`. + # + # Either I'm missing something important, or that's just GUI programming for you. + # + # Well, to look at the bright side, at least this gives us a place where we can + # compute the render FPS after the render is actually complete. + def update_images_cont2() -> None: + self.Refresh() + self.Update() + + # Update FPS counter, but only if a render actually took place (we want to measure the render speed only). + if actually_rendered: + elapsed_time = time.time_ns() - last_update_time + fps = 1.0 / (elapsed_time / 10**9) + if self.torch_source_image is not None: + self.fps_statistics.add_fps(fps) + self.fps_text.SetLabelText(f"Render: {self.fps_statistics.get_average_fps():0.2f} FPS") + + self.update_in_progress = False + wx.CallAfter(update_images_cont2) + wx.CallAfter(update_images_cont) + + def current_pose_to_posedict(self) -> Dict[str, float]: + """Convert the character's current pose into a posedict for saving into an emotion JSON.""" + current_pose_values = self.get_current_pose() + current_pose_dict = dict(zip(posedict_keys, current_pose_values)) return current_pose_dict - def on_save_image(self, event: wx.Event): + def posedict_to_pose(self, posedict: Dict[str, float]) -> List[float]: + """Convert a posedict (from an emotion JSON) into a list of morph values (in the order the models expect them).""" + # sanity check + unrecognized_keys = set(posedict.keys()) - set(posedict_keys) + if unrecognized_keys: + logger.warning(f"Ignoring unrecognized keys in posedict: {unrecognized_keys}") + # Missing keys are fine - keys for zero values can simply be omitted. + + pose = [0.0 for i in range(self.poser.get_num_parameters())] + for idx, key in enumerate(posedict_keys): + pose[idx] = posedict.get(key, 0.0) + return pose + + def on_save_image(self, event: wx.Event) -> None: + """Ask the user for destination and save the output image. + + The pose is automatically saved into the same directory as the output image, with + file name determined from the image file name (e.g. "my_emotion.png" -> "my_emotion.json"). + """ if self.last_output_numpy_image is None: - logging.info("There is no output image to save!!!") + logger.info("There is no output image to save.") + return + dir_name = "output" + file_dialog = wx.FileDialog(self, "Save output image", dir_name, "", "PNG images (*.png)|*.png", wx.FD_SAVE) + # try: # multi-format support: select PNG save format by default if available + # file_dialog.SetFilterIndex(output_ext_to_index["png"]) + # except Exception: + # pass + try: + if file_dialog.ShowModal() == wx.ID_OK: + image_file_name = file_dialog.GetFilename() + # idx = file_dialog.GetFilterIndex() + # ext = output_index_to_ext[idx] + # if ext and not image_file_name.lower().endswith(f".{ext}"): # usability: auto-add selected file extension + # image_file_name += f".{ext}" + if not image_file_name.lower().endswith(".png"): # usability: auto-add .png file extension + image_file_name += ".png" + + image_file_name = os.path.join(file_dialog.GetDirectory(), image_file_name) + try: + if os.path.exists(image_file_name): + message_dialog = wx.MessageDialog(self, f"Overwrite {image_file_name}?", "THA3 Manual Poser", + wx.YES_NO | wx.ICON_QUESTION) + try: + result = message_dialog.ShowModal() + if result == wx.ID_NO: + return + self.save_numpy_image(self.last_output_numpy_image, image_file_name) + finally: + message_dialog.Destroy() + else: + self.save_numpy_image(self.last_output_numpy_image, image_file_name) + + except Exception as exc: + logger.error(f"Could not save {image_file_name}, reason: {exc}") + message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}, reason: {exc}", "THA3 Manual Poser", wx.OK) + try: + message_dialog.ShowModal() + finally: + message_dialog.Destroy() + + else: # Since it is possible to save the image and JSON to "tha3/emotions", on a successful save, refresh the emotion presets list. + logger.info(f"Saved image {image_file_name}") + + current_emotion_old_index = self.emotion_choice.GetSelection() + current_emotion_name = self.emotion_choice.GetString(current_emotion_old_index) + + self.emotions, self.emotion_names = load_emotion_presets() + self.emotion_choice.SetItems(self.emotion_names) + + current_emotion_new_index = self.emotion_choice.FindString(current_emotion_name) + self.emotion_choice.SetSelection(current_emotion_new_index) + finally: + file_dialog.Destroy() + + def on_save_all_emotions(self, event: wx.Event) -> None: + """Ask the user for a destination directory, and batch save an output image using each of the emotion presets. + + Does not affect the output image displayed in the GUI. + """ + if self.torch_source_image is None: + logger.info("No image is loaded, nothing to batch.") return - #keys = ['eyebrow_troubled_left_index', 'eyebrow_troubled_right_index', 'eyebrow_angry_left_index', 'eyebrow_angry_right_index', 'eyebrow_lowered_left_index', 'eyebrow_lowered_right_index', 'eyebrow_raised_left_index', 'eyebrow_raised_right_index', 'eyebrow_happy_left_index', 'eyebrow_happy_right_index', 'eyebrow_serious_left_index', 'eyebrow_serious_right_index', 'eye_wink_left_index', 'eye_wink_right_index', 'eye_happy_wink_left_index', 'eye_happy_wink_right_index', 'eye_surprised_left_index', 'eye_surprised_right_index', 'eye_relaxed_left_index', 'eye_relaxed_right_index', 'eye_unimpressed', 'eye_unimpressed', 'eye_raised_lower_eyelid_left_index', 'eye_raised_lower_eyelid_right_index', 'iris_small_left_index', 'iris_small_right_index', 'mouth_aaa_index', 'mouth_iii_index', 'mouth_uuu_index', 'mouth_eee_index', 'mouth_ooo_index', 'mouth_delta', 'mouth_lowered_corner_left_index', 'mouth_lowered_corner_right_index', 'mouth_raised_corner_left_index', 'mouth_raised_corner_right_index', 'mouth_smirk', 'iris_rotation_x_index', 'iris_rotation_y_index', 'head_x_index', 'head_y_index', 'neck_z_index', 'body_y_index', 'body_z_index', 'breathing_index'] - #current_pose_dict = dict(zip(keys, self.get_current_pose())) - #print(current_pose_dict) - # output settings to console. + dir_dialog = wx.DirDialog(self, "Choose directory to save in", "output", wx.DD_DEFAULT_STYLE) + try: + if dir_dialog.ShowModal() == wx.ID_OK: + dir_name = dir_dialog.GetPath() + if not os.path.exists(dir_name): + p = pathlib.Path(dir_name).expanduser().resolve() + pathlib.Path.mkdir(p, parents=True, exist_ok=True) + if os.listdir(dir_name): # not empty + # TODO: provide replace and merge modes + message_dialog = wx.MessageDialog(self, f"Directory is not empty: {dir_name}.\nAny files corresponding to emotion presets will be overwritten.\nProceed?", "THA3 Manual Poser", + wx.YES_NO | wx.ICON_QUESTION) + try: + result = message_dialog.ShowModal() + if result == wx.ID_NO: + return + finally: + message_dialog.Destroy() - dir_name = "data/images" - file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE) - if file_dialog.ShowModal() == wx.ID_OK: - image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename()) - try: - if os.path.exists(image_file_name): - message_dialog = wx.MessageDialog(self, f"Override {image_file_name}", "Manual Poser", - wx.YES_NO | wx.ICON_QUESTION) - result = message_dialog.ShowModal() - if result == wx.ID_YES: - self.save_last_numpy_image(image_file_name) - else: - self.save_last_numpy_image(image_file_name) + logger.info(f"Batch saving output based on all emotion presets to directory {dir_name}...") + for emotion_name, posedict in self.emotions.items(): + if emotion_name.startswith("[") and emotion_name.endswith("]"): + continue # skip "[custom]" and "[reset]" + try: + pose = self.posedict_to_pose(posedict) + posetensor = torch.tensor(pose, device=self.device, dtype=self.dtype) + output_index = self.output_index_choice.GetSelection() + with torch.no_grad(): + output_image = self.poser.pose(self.torch_source_image, posetensor, output_index)[0].detach().cpu() + numpy_image = convert_output_image_from_torch_to_numpy(output_image) - except: - message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK) - message_dialog.ShowModal() - message_dialog.Destroy() - file_dialog.Destroy() + image_file_name = os.path.join(dir_name, f"{emotion_name}.png") + self.save_numpy_image(numpy_image, image_file_name) - def save_last_numpy_image(self, image_file_name): - numpy_image = self.last_output_numpy_image - pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA') + logger.info(f"Saved image {image_file_name}") + except Exception as exc: + logger.error(f"Could not save {image_file_name}, reason: {exc}") + logger.info("Batch save finished.") + finally: + dir_dialog.Destroy() + + def save_numpy_image(self, numpy_image: numpy.array, image_file_name: str) -> None: + """Save the output image. + + Output format is determined by file extension (which must be supported by the installed `Pillow`). + Automatically save also the corresponding settings as JSON. + + The settings are saved into the same directory as the output image, with file name determined + from the image file name (e.g. "my_emotion.png" -> "my_emotion.json"). + """ + pil_image = PIL.Image.fromarray(numpy_image, mode="RGBA") os.makedirs(os.path.dirname(image_file_name), exist_ok=True) pil_image.save(image_file_name) - - - data_dict = self.get_current_posedict() # Get values - json_file_path = os.path.splitext(image_file_name)[0] + ".json" # Generate JSON file path + + data_dict = self.current_pose_to_posedict() + json_file_path = os.path.splitext(image_file_name)[0] + ".json" filename_without_extension = os.path.splitext(os.path.basename(image_file_name))[0] data_dict_with_filename = {filename_without_extension: data_dict} # Create a new dict with the filename as the key - with open(json_file_path, "w") as file: - json.dump(data_dict_with_filename, file, indent=4) - - + try: + with open(json_file_path, "w") as file: + json.dump(data_dict_with_filename, file, indent=4) + except Exception: + pass + else: + logger.info(f"Saved JSON {json_file_path}") if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Manually pose a character image.') - parser.add_argument( - '--model', - type=str, - required=False, - default='separable_float', - choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'], - help='The model to use.') + parser = argparse.ArgumentParser(description="THA 3 Manual Poser. Pose a character image manually. Useful for generating static expression images.") + parser.add_argument("--model", + type=str, + required=False, + default="separable_float", + choices=["standard_float", "separable_float", "standard_half", "separable_half"], + help="The model to use. 'float' means fp32, 'half' means fp16.") + parser.add_argument("--device", + type=str, + required=False, + default="cuda", + choices=["cpu", "cuda"], + help='The device to use for PyTorch ("cuda" for GPU, "cpu" for CPU).') args = parser.parse_args() - device = torch.device('cuda') try: - poser = load_poser(args.model, device) + device = torch.device(args.device) + poser = load_poser(args.model, device, modelsdir="tha3/models") except RuntimeError as e: - print(e) + logger.error(e) sys.exit() + # Create the "talkinghead/output" directory if it doesn't exist. This is our default save location. + p = pathlib.Path("output").expanduser().resolve() + pathlib.Path.mkdir(p, parents=True, exist_ok=True) + app = wx.App() - main_frame = MainFrame(poser, device) + main_frame = MainFrame(poser, device, args.model) main_frame.Show(True) main_frame.timer.Start(30) app.MainLoop() diff --git a/talkinghead/tha3/poser/modes/load_poser.py b/talkinghead/tha3/poser/modes/load_poser.py index a4f371a..f5c88c2 100644 --- a/talkinghead/tha3/poser/modes/load_poser.py +++ b/talkinghead/tha3/poser/modes/load_poser.py @@ -1,19 +1,19 @@ import torch -def load_poser(model: str, device: torch.device): +def load_poser(model: str, device: torch.device, modelsdir="talkinghead/tha3/models"): print("Using the %s model." % model) if model == "standard_float": from tha3.poser.modes.standard_float import create_poser - return create_poser(device) + return create_poser(device, modelsdir=modelsdir) elif model == "standard_half": from tha3.poser.modes.standard_half import create_poser - return create_poser(device) + return create_poser(device, modelsdir=modelsdir) elif model == "separable_float": from tha3.poser.modes.separable_float import create_poser - return create_poser(device) + return create_poser(device, modelsdir=modelsdir) elif model == "separable_half": from tha3.poser.modes.separable_half import create_poser - return create_poser(device) + return create_poser(device, modelsdir=modelsdir) else: - raise RuntimeError("Invalid model: '%s'" % model) \ No newline at end of file + raise RuntimeError("Invalid model: '%s'" % model) diff --git a/talkinghead/tha3/poser/modes/pose_parameters.py b/talkinghead/tha3/poser/modes/pose_parameters.py index 433a68d..bef82b6 100644 --- a/talkinghead/tha3/poser/modes/pose_parameters.py +++ b/talkinghead/tha3/poser/modes/pose_parameters.py @@ -33,4 +33,4 @@ def get_pose_parameters(): .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ - .build() \ No newline at end of file + .build() diff --git a/talkinghead/tha3/poser/modes/separable_float.py b/talkinghead/tha3/poser/modes/separable_float.py index 4fd07cc..7769b26 100644 --- a/talkinghead/tha3/poser/modes/separable_float.py +++ b/talkinghead/tha3/poser/modes/separable_float.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import Dict, Optional, List +import os +from typing import Dict, List, Optional import torch from torch import Tensor @@ -259,30 +260,31 @@ def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, - default_output_index: int = 0) -> GeneralPoser02: + default_output_index: int = 0, + modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_float" - file_name = dir + "/eyebrow_decomposer.pt" + file_name = os.path.join(modelsdir, "separable_float", "eyebrow_decomposer.pt") module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_float" - file_name = dir + "/eyebrow_morphing_combiner.pt" + file_name = os.path.join(modelsdir, "separable_float", "eyebrow_morphing_combiner.pt") module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_float" - file_name = dir + "/face_morpher.pt" + file_name = os.path.join(modelsdir, "separable_float", "face_morpher.pt") module_file_names[Network.face_morpher.name] = file_name if Network.two_algo_face_body_rotator.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_float" - file_name = dir + "/two_algo_face_body_rotator.pt" + file_name = os.path.join(modelsdir, "separable_float", "two_algo_face_body_rotator.pt") module_file_names[Network.two_algo_face_body_rotator.name] = file_name if Network.editor.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_float" - file_name = dir + "/editor.pt" + file_name = os.path.join(modelsdir, "separable_float", "editor.pt") module_file_names[Network.editor.name] = file_name + # fail-fast + for file_name in module_file_names.values(): + if not os.path.exists(file_name): + raise FileNotFoundError(f"Model file {file_name} not found, please check the path.") + loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), @@ -328,4 +330,4 @@ if __name__ == "__main__": print("%d:" % i, elapsed_time) acc = acc + elapsed_time - print("average:", acc / repeat) \ No newline at end of file + print("average:", acc / repeat) diff --git a/talkinghead/tha3/poser/modes/separable_half.py b/talkinghead/tha3/poser/modes/separable_half.py index 24d29ed..1e887a6 100644 --- a/talkinghead/tha3/poser/modes/separable_half.py +++ b/talkinghead/tha3/poser/modes/separable_half.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import List, Dict, Optional +import os +from typing import Dict, List, Optional import torch from torch import Tensor @@ -12,10 +13,10 @@ from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_03 import \ EyebrowMorphingCombiner03Factory, EyebrowMorphingCombiner03Args, EyebrowMorphingCombiner03 from tha3.nn.face_morpher.face_morpher_09 import FaceMorpher09Factory, FaceMorpher09Args from tha3.poser.general_poser_02 import GeneralPoser02 -from tha3.poser.poser import PoseParameterCategory, PoseParameters from tha3.nn.editor.editor_07 import Editor07, Editor07Args from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ TwoAlgoFaceBodyRotator05Args +from tha3.poser.modes.pose_parameters import get_pose_parameters from tha3.util import torch_load from tha3.compute.cached_computation_func import TensorListCachedComputationFunc from tha3.compute.cached_computation_protocol import CachedComputationProtocol @@ -253,69 +254,35 @@ def load_editor(file_name) -> Module: return module -def get_pose_parameters(): - return PoseParameters.Builder() \ - .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ - .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ - .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ - .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ - .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ - .build() - - def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner03.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, - default_output_index: int = 0) -> GeneralPoser02: + default_output_index: int = 0, + modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_half" - file_name = dir + "/eyebrow_decomposer.pt" + file_name = os.path.join(modelsdir, "separable_half", "eyebrow_decomposer.pt") module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_half" - file_name = dir + "/eyebrow_morphing_combiner.pt" + file_name = os.path.join(modelsdir, "separable_half", "eyebrow_morphing_combiner.pt") module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_half" - file_name = dir + "/face_morpher.pt" + file_name = os.path.join(modelsdir, "separable_half", "face_morpher.pt") module_file_names[Network.face_morpher.name] = file_name if Network.two_algo_face_body_rotator.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_half" - file_name = dir + "/two_algo_face_body_rotator.pt" + file_name = os.path.join(modelsdir, "separable_half", "two_algo_face_body_rotator.pt") module_file_names[Network.two_algo_face_body_rotator.name] = file_name if Network.editor.name not in module_file_names: - dir = "talkinghead/tha3/models/separable_half" - file_name = dir + "/editor.pt" + file_name = os.path.join(modelsdir, "separable_half", "editor.pt") module_file_names[Network.editor.name] = file_name + # fail-fast + for file_name in module_file_names.values(): + if not os.path.exists(file_name): + raise FileNotFoundError(f"Model file {file_name} not found, please check the path.") + loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), diff --git a/talkinghead/tha3/poser/modes/standard_float.py b/talkinghead/tha3/poser/modes/standard_float.py index f75008f..15d686f 100644 --- a/talkinghead/tha3/poser/modes/standard_float.py +++ b/talkinghead/tha3/poser/modes/standard_float.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import List, Dict, Optional +import os +from typing import Dict, List, Optional import torch from torch import Tensor @@ -12,10 +13,10 @@ from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \ EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00 from tha3.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory from tha3.poser.general_poser_02 import GeneralPoser02 -from tha3.poser.poser import PoseParameterCategory, PoseParameters from tha3.nn.editor.editor_07 import Editor07, Editor07Args from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ TwoAlgoFaceBodyRotator05Args +from tha3.poser.modes.pose_parameters import get_pose_parameters from tha3.util import torch_load from tha3.compute.cached_computation_func import TensorListCachedComputationFunc from tha3.compute.cached_computation_protocol import CachedComputationProtocol @@ -251,69 +252,35 @@ def load_editor(file_name) -> Module: return module -def get_pose_parameters(): - return PoseParameters.Builder() \ - .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ - .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ - .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ - .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ - .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ - .build() - - def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, - default_output_index: int = 0) -> GeneralPoser02: + default_output_index: int = 0, + modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_float" - file_name = dir + "/eyebrow_decomposer.pt" + file_name = os.path.join(modelsdir, "standard_float", "eyebrow_decomposer.pt") module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_float" - file_name = dir + "/eyebrow_morphing_combiner.pt" + file_name = os.path.join(modelsdir, "standard_float", "eyebrow_morphing_combiner.pt") module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_float" - file_name = dir + "/face_morpher.pt" + file_name = os.path.join(modelsdir, "standard_float", "face_morpher.pt") module_file_names[Network.face_morpher.name] = file_name if Network.two_algo_face_body_rotator.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_float" - file_name = dir + "/two_algo_face_body_rotator.pt" + file_name = os.path.join(modelsdir, "standard_float", "two_algo_face_body_rotator.pt") module_file_names[Network.two_algo_face_body_rotator.name] = file_name if Network.editor.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_float" - file_name = dir + "/editor.pt" + file_name = os.path.join(modelsdir, "standard_float", "editor.pt") module_file_names[Network.editor.name] = file_name + # fail-fast + for file_name in module_file_names.values(): + if not os.path.exists(file_name): + raise FileNotFoundError(f"Model file {file_name} not found, please check the path.") + loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]), diff --git a/talkinghead/tha3/poser/modes/standard_half.py b/talkinghead/tha3/poser/modes/standard_half.py index 5d28bc3..fb677c6 100644 --- a/talkinghead/tha3/poser/modes/standard_half.py +++ b/talkinghead/tha3/poser/modes/standard_half.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import List, Dict, Optional +import os +from typing import Dict, List, Optional import torch from torch import Tensor @@ -12,10 +13,10 @@ from tha3.nn.eyebrow_morphing_combiner.eyebrow_morphing_combiner_00 import \ EyebrowMorphingCombiner00Factory, EyebrowMorphingCombiner00Args, EyebrowMorphingCombiner00 from tha3.nn.face_morpher.face_morpher_08 import FaceMorpher08Args, FaceMorpher08Factory from tha3.poser.general_poser_02 import GeneralPoser02 -from tha3.poser.poser import PoseParameterCategory, PoseParameters from tha3.nn.editor.editor_07 import Editor07, Editor07Args from tha3.nn.two_algo_body_rotator.two_algo_face_body_rotator_05 import TwoAlgoFaceBodyRotator05, \ TwoAlgoFaceBodyRotator05Args +from tha3.poser.modes.pose_parameters import get_pose_parameters from tha3.util import torch_load from tha3.compute.cached_computation_func import TensorListCachedComputationFunc from tha3.compute.cached_computation_protocol import CachedComputationProtocol @@ -251,69 +252,35 @@ def load_editor(file_name) -> Module: return module -def get_pose_parameters(): - return PoseParameters.Builder() \ - .add_parameter_group("eyebrow_troubled", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_angry", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_lowered", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_raised", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_happy", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eyebrow_serious", PoseParameterCategory.EYEBROW, arity=2) \ - .add_parameter_group("eye_wink", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_happy_wink", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_surprised", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_relaxed", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_unimpressed", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("eye_raised_lower_eyelid", PoseParameterCategory.EYE, arity=2) \ - .add_parameter_group("iris_small", PoseParameterCategory.IRIS_MORPH, arity=2) \ - .add_parameter_group("mouth_aaa", PoseParameterCategory.MOUTH, arity=1, default_value=1.0) \ - .add_parameter_group("mouth_iii", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_uuu", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_eee", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_ooo", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_delta", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("mouth_lowered_corner", PoseParameterCategory.MOUTH, arity=2) \ - .add_parameter_group("mouth_raised_corner", PoseParameterCategory.MOUTH, arity=2) \ - .add_parameter_group("mouth_smirk", PoseParameterCategory.MOUTH, arity=1) \ - .add_parameter_group("iris_rotation_x", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("iris_rotation_y", PoseParameterCategory.IRIS_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("head_x", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("head_y", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("neck_z", PoseParameterCategory.FACE_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("body_y", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("body_z", PoseParameterCategory.BODY_ROTATION, arity=1, range=(-1.0, 1.0)) \ - .add_parameter_group("breathing", PoseParameterCategory.BREATHING, arity=1, range=(0.0, 1.0)) \ - .build() - - def create_poser( device: torch.device, module_file_names: Optional[Dict[str, str]] = None, eyebrow_morphed_image_index: int = EyebrowMorphingCombiner00.EYEBROW_IMAGE_NO_COMBINE_ALPHA_INDEX, - default_output_index: int = 0) -> GeneralPoser02: + default_output_index: int = 0, + modelsdir: str = "talkinghead/tha3/models") -> GeneralPoser02: if module_file_names is None: module_file_names = {} if Network.eyebrow_decomposer.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_half" - file_name = dir + "/eyebrow_decomposer.pt" + file_name = os.path.join(modelsdir, "standard_half", "eyebrow_decomposer.pt") module_file_names[Network.eyebrow_decomposer.name] = file_name if Network.eyebrow_morphing_combiner.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_half" - file_name = dir + "/eyebrow_morphing_combiner.pt" + file_name = os.path.join(modelsdir, "standard_half", "eyebrow_morphing_combiner.pt") module_file_names[Network.eyebrow_morphing_combiner.name] = file_name if Network.face_morpher.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_half" - file_name = dir + "/face_morpher.pt" + file_name = os.path.join(modelsdir, "standard_half", "face_morpher.pt") module_file_names[Network.face_morpher.name] = file_name if Network.two_algo_face_body_rotator.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_half" - file_name = dir + "/two_algo_face_body_rotator.pt" + file_name = os.path.join(modelsdir, "standard_half", "two_algo_face_body_rotator.pt") module_file_names[Network.two_algo_face_body_rotator.name] = file_name if Network.editor.name not in module_file_names: - dir = "talkinghead/tha3/models/standard_half" - file_name = dir + "/editor.pt" + file_name = os.path.join(modelsdir, "standard_half", "editor.pt") module_file_names[Network.editor.name] = file_name + # fail-fast + for file_name in module_file_names.values(): + if not os.path.exists(file_name): + raise FileNotFoundError(f"Model file {file_name} not found, please check the path.") + loaders = { Network.eyebrow_decomposer.name: lambda: load_eyebrow_decomposer(module_file_names[Network.eyebrow_decomposer.name]),