mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-28 18:31:19 +00:00
501 lines
23 KiB
Python
501 lines
23 KiB
Python
import argparse
|
|
import logging
|
|
import os
|
|
import sys
|
|
import PIL.Image
|
|
import numpy
|
|
import torch
|
|
import wx
|
|
import json
|
|
from typing import List
|
|
|
|
# Set the working directory to the "live2d" subdirectory to work with file structure
# NOTE: this is resolved relative to the *current working directory* of the process,
# not relative to this file, so the script must be launched from the directory that
# contains "live2d".
target_directory = os.path.join(os.getcwd(), "live2d")
os.chdir(target_directory)
# Make the new working directory importable; presumably this is where the `tha3`
# package imported below lives -- TODO confirm against the repository layout.
sys.path.append(os.getcwd())

# These imports must stay *after* the sys.path manipulation above.
from tha3.poser.modes.load_poser import load_poser
from tha3.poser.poser import Poser, PoseParameterCategory, PoseParameterGroup
from tha3.util import extract_pytorch_image_from_filelike, rgba_to_numpy_image, grid_change_to_numpy_image, \
    rgb_to_numpy_image, resize_PIL_image, extract_PIL_image_from_filelike, extract_pytorch_image_from_PIL_image

# Step back up one level so that the rest of the program runs with the parent of
# the "live2d" directory as its working directory again.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)
os.chdir(parent_directory)
|
|
|
|
class MorphCategoryControlPanel(wx.Panel):
    """Panel for editing one morph parameter group out of a pose category.

    The panel shows a title, a drop-down listing every parameter group whose
    category equals ``pose_param_category``, two sliders ("left" and "right")
    and a "Show" checkbox.  Which of the controls are enabled depends on the
    currently selected group (see :meth:`update_ui`).
    """

    # The sliders run over [-SLIDER_EXTENT, +SLIDER_EXTENT].
    SLIDER_EXTENT = 1000

    def __init__(self,
                 parent,
                 title: str,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        """Build the widgets.

        :param parent: parent wx window.
        :param title: text shown at the top of the panel.
        :param pose_param_category: category this panel is responsible for.
        :param param_groups: all parameter groups of the poser; only those
            whose ``get_category()`` equals ``pose_param_category`` are kept.
        """
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.pose_param_category = pose_param_category
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(1)

        title_text = wx.StaticText(self, label=title, style=wx.ALIGN_CENTER)
        self.sizer.Add(title_text, 0, wx.EXPAND)

        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        self.choice = wx.Choice(self, choices=[group.get_group_name() for group in self.param_groups])
        if self.param_groups:
            self.choice.SetSelection(0)
        self.choice.Bind(wx.EVT_CHOICE, self.on_choice_updated)
        self.sizer.Add(self.choice, 0, wx.EXPAND)

        # Both sliders start at their minimum, i.e. at the lower end of the
        # parameter range (see set_param_value).
        self.left_slider = wx.Slider(self, minValue=-self.SLIDER_EXTENT, maxValue=self.SLIDER_EXTENT,
                                     value=-self.SLIDER_EXTENT, style=wx.HORIZONTAL)
        self.sizer.Add(self.left_slider, 0, wx.EXPAND)

        self.right_slider = wx.Slider(self, minValue=-self.SLIDER_EXTENT, maxValue=self.SLIDER_EXTENT,
                                      value=-self.SLIDER_EXTENT, style=wx.HORIZONTAL)
        self.sizer.Add(self.right_slider, 0, wx.EXPAND)

        self.checkbox = wx.CheckBox(self, label="Show")
        self.checkbox.SetValue(True)
        self.sizer.Add(self.checkbox, 0, wx.SHAPED | wx.ALIGN_CENTER)

        self.update_ui()

        self.sizer.Fit(self)

    def _selected_param_group(self):
        """Return the parameter group selected in the drop-down, or ``None``.

        ``None`` is returned when the panel has no parameter groups or when
        the drop-down has no selection, so callers never index the group list
        with an invalid position.
        """
        selection = self.choice.GetSelection()
        if not self.param_groups or selection == wx.NOT_FOUND:
            return None
        return self.param_groups[selection]

    def _slider_alpha(self, slider) -> float:
        """Map a slider position from [-extent, +extent] onto [0.0, 1.0]."""
        return (slider.GetValue() + self.SLIDER_EXTENT) / (2.0 * self.SLIDER_EXTENT)

    def update_ui(self):
        """Enable/disable the controls to match the selected parameter group.

        * discrete group: only the checkbox is enabled;
        * non-discrete group of arity 1: only the left slider is enabled;
        * any other non-discrete group: both sliders are enabled;
        * no group available: every control is disabled.
        """
        param_group = self._selected_param_group()
        if param_group is None:
            # Nothing to edit: previously this path raised IndexError.
            self.left_slider.Enable(False)
            self.right_slider.Enable(False)
            self.checkbox.Enable(False)
            return

        discrete = param_group.is_discrete()
        self.left_slider.Enable(not discrete)
        self.right_slider.Enable(not discrete and param_group.get_arity() != 1)
        self.checkbox.Enable(discrete)

    def on_choice_updated(self, event: wx.Event):
        """Handle a change of the drop-down selection.

        Switching to a discrete group re-checks the "Show" checkbox before
        the enabled state of the controls is refreshed.
        """
        param_group = self._selected_param_group()
        if param_group is not None and param_group.is_discrete():
            self.checkbox.SetValue(True)
        self.update_ui()

    def set_param_value(self, pose: List[float]):
        """Write this panel's current setting into ``pose`` in place.

        For a discrete group every parameter of the group is set to ``1.0``
        when the checkbox is ticked (and left untouched otherwise).  For a
        non-discrete group the left slider is interpolated linearly over the
        group's range into the first parameter, and, for groups of arity 2,
        the right slider into the second parameter.

        :param pose: flat list of pose parameters, indexed by the values
            returned from ``get_parameter_index()``.
        """
        param_group = self._selected_param_group()
        if param_group is None:
            return

        param_index = param_group.get_parameter_index()
        if param_group.is_discrete():
            if self.checkbox.GetValue():
                for offset in range(param_group.get_arity()):
                    pose[param_index + offset] = 1.0
            return

        param_range = param_group.get_range()
        low = param_range[0]
        span = param_range[1] - param_range[0]
        pose[param_index] = low + span * self._slider_alpha(self.left_slider)
        if param_group.get_arity() == 2:
            pose[param_index + 1] = low + span * self._slider_alpha(self.right_slider)
|
|
|
|
|
|
class SimpleParamGroupsControlPanel(wx.Panel):
    """Panel with one labelled slider per parameter group of a category.

    Only continuous (non-discrete) parameter groups of arity 1 are supported;
    this is asserted in the constructor.
    """

    def __init__(self, parent,
                 pose_param_category: PoseParameterCategory,
                 param_groups: List[PoseParameterGroup]):
        """Build one caption + slider pair per matching parameter group.

        :param parent: parent wx window.
        :param pose_param_category: category this panel is responsible for.
        :param param_groups: all parameter groups of the poser; only those
            whose ``get_category()`` equals ``pose_param_category`` are kept.
        """
        super().__init__(parent, style=wx.SIMPLE_BORDER)
        self.sizer = wx.BoxSizer(wx.VERTICAL)
        self.SetSizer(self.sizer)
        self.SetAutoLayout(1)

        self.param_groups = [group for group in param_groups if group.get_category() == pose_param_category]
        for param_group in self.param_groups:
            assert not param_group.is_discrete()
            assert param_group.get_arity() == 1

        self.sliders = []
        for param_group in self.param_groups:
            static_text = wx.StaticText(
                self,
                label="   ------------ %s ------------   " % param_group.get_group_name(), style=wx.ALIGN_CENTER)
            self.sizer.Add(static_text, 0, wx.EXPAND)
            # Named `param_range` (not `range`) so the builtin is not shadowed.
            param_range = param_group.get_range()
            # Slider positions are integers, so the range is scaled by 1000.
            min_value = int(param_range[0] * 1000)
            max_value = int(param_range[1] * 1000)
            slider = wx.Slider(self, minValue=min_value, maxValue=max_value, value=0, style=wx.HORIZONTAL)
            self.sizer.Add(slider, 0, wx.EXPAND)
            self.sliders.append(slider)

        self.sizer.Fit(self)

    def set_param_value(self, pose: List[float]):
        """Write the value of every slider into ``pose`` in place.

        Each slider position is normalised to ``[0, 1]`` relative to the
        slider's own min/max and then interpolated linearly over the
        parameter group's range.

        :param pose: flat list of pose parameters, indexed by the values
            returned from ``get_parameter_index()``.
        """
        for param_group, slider in zip(self.param_groups, self.sliders):
            param_range = param_group.get_range()
            param_index = param_group.get_parameter_index()
            slider_min = slider.GetMin()
            slider_span = slider.GetMax() - slider_min
            if slider_span == 0:
                # Degenerate slider (min == max): avoid ZeroDivisionError and
                # fall back to the lower end of the parameter range.
                alpha = 0.0
            else:
                alpha = (slider.GetValue() - slider_min) / float(slider_span)
            pose[param_index] = param_range[0] + (param_range[1] - param_range[0]) * alpha
|
|
|
|
|
|
def convert_output_image_from_torch_to_numpy(output_image):
    """Convert a poser output tensor into a ``uint8`` image array.

    The conversion routine is picked from the tensor's shape:

    * last dimension of size 2: the tensor is reshaped/transposed so that
      this dimension comes first;
    * first dimension 4 / 3: handed to ``rgba_to_numpy_image`` /
      ``rgb_to_numpy_image``;
    * first dimension 1: replicated to three channels, rescaled with
      ``* 2.0 - 1.0``, given an all-ones fourth channel and then handed to
      ``rgba_to_numpy_image``;
    * first dimension 2: handed to ``grid_change_to_numpy_image``.

    The intermediate result is multiplied by 255, rounded and cast to uint8.

    :raises RuntimeError: if none of the shape patterns above matches.
    """
    shape = output_image.shape
    if shape[2] == 2:
        height, width, channels = shape
        flattened = output_image.reshape(height * width, channels)
        converted = torch.transpose(flattened, 0, 1).reshape(channels, height, width)
    else:
        leading = shape[0]
        if leading == 4:
            converted = rgba_to_numpy_image(output_image)
        elif leading == 3:
            converted = rgb_to_numpy_image(output_image)
        elif leading == 1:
            _, height, width = shape
            replicated = output_image.repeat(3, 1, 1) * 2.0 - 1.0
            ones_channel = torch.ones(1, height, width)
            converted = rgba_to_numpy_image(torch.cat([replicated, ones_channel], dim=0))
        elif leading == 2:
            converted = grid_change_to_numpy_image(output_image, num_channels=4)
        else:
            raise RuntimeError("Unsupported # image channels: %d" % leading)

    rounded = numpy.rint(converted * 255.0)
    return numpy.uint8(rounded)
|
|
|
|
|
|
class MainFrame(wx.Frame):
    """Main window of the manual poser.

    Layout (left to right): the source image with a "Load Image" button, the
    pose control panels, and the posed result with an output selector and a
    "Save Image" button.  A wx timer (started by the caller) periodically
    invokes :meth:`update_images`, which re-runs the poser whenever the pose,
    the selected output or the source image has changed.
    """

    def __init__(self, poser: Poser, device: torch.device):
        """Create all widgets and event bindings.

        :param poser: model wrapper used to pose the source image.
        :param device: torch device that input tensors are moved to.
        """
        super().__init__(None, wx.ID_ANY, "Poser")
        self.poser = poser
        self.dtype = self.poser.get_dtype()
        self.device = device
        self.image_size = self.poser.get_image_size()

        self.wx_source_image = None
        self.torch_source_image = None

        self.main_sizer = wx.BoxSizer(wx.HORIZONTAL)
        self.SetSizer(self.main_sizer)
        self.SetAutoLayout(1)
        self.init_left_panel()
        self.init_control_panel()
        self.init_right_panel()
        self.main_sizer.Fit(self)

        self.timer = wx.Timer(self, wx.ID_ANY)
        self.Bind(wx.EVT_TIMER, self.update_images, self.timer)

        # Ctrl+S triggers the same handler as the "Save Image" button.
        save_image_id = wx.NewIdRef()
        self.Bind(wx.EVT_MENU, self.on_save_image, id=save_image_id)
        accelerator_table = wx.AcceleratorTable([
            (wx.ACCEL_CTRL, ord('S'), save_image_id)
        ])
        self.SetAcceleratorTable(accelerator_table)

        # State used by update_images() to skip redundant re-renders.
        self.last_pose = None
        self.last_output_index = self.output_index_choice.GetSelection()
        self.last_output_numpy_image = None

        # Off-screen bitmaps that the paint handlers blit to the panels.
        self.source_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.result_image_bitmap = wx.Bitmap(self.image_size, self.image_size)
        self.source_image_dirty = True

    def init_left_panel(self):
        """Create the source-image panel and the "Load Image" button.

        The (still empty) control panel window is also created here, before
        the left panel; it is populated later by :meth:`init_control_panel`.
        """
        self.control_panel = wx.Panel(self, style=wx.SIMPLE_BORDER, size=(self.image_size, -1))
        self.left_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        left_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.left_panel.SetSizer(left_panel_sizer)
        self.left_panel.SetAutoLayout(1)

        self.source_image_panel = wx.Panel(self.left_panel, size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.source_image_panel.Bind(wx.EVT_PAINT, self.paint_source_image_panel)
        self.source_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        left_panel_sizer.Add(self.source_image_panel, 0, wx.FIXED_MINSIZE)

        self.load_image_button = wx.Button(self.left_panel, wx.ID_ANY, "\nLoad Image\n\n")
        left_panel_sizer.Add(self.load_image_button, 1, wx.EXPAND)
        self.load_image_button.Bind(wx.EVT_BUTTON, self.load_image)

        left_panel_sizer.Fit(self.left_panel)
        self.main_sizer.Add(self.left_panel, 0, wx.FIXED_MINSIZE)

    def on_erase_background(self, event: wx.Event):
        """Deliberately do nothing when asked to erase the background.

        The image panels are fully repainted from their bitmaps, so the
        default erase is suppressed (presumably to avoid flicker).
        """
        pass

    def init_control_panel(self):
        """Populate the control panel with one sub-panel per pose category.

        Categories for which the poser exposes no parameter group are skipped.
        Morph categories get a :class:`MorphCategoryControlPanel`; the other
        categories get a :class:`SimpleParamGroupsControlPanel`.
        """
        self.control_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.control_panel.SetSizer(self.control_panel_sizer)
        self.control_panel.SetMinSize(wx.Size(256, 1))

        # Queried once; every panel filters this list by its own category.
        param_groups = self.poser.get_pose_parameter_groups()

        def has_groups(category) -> bool:
            return any(group.get_category() == category for group in param_groups)

        morph_categories = [
            (PoseParameterCategory.EYEBROW, "   ------------ Eyebrow ------------   "),
            (PoseParameterCategory.EYE, "   ------------ Eye ------------   "),
            (PoseParameterCategory.MOUTH, "   ------------ Mouth ------------   "),
            (PoseParameterCategory.IRIS_MORPH, "   ------------ Iris morphs ------------   "),
        ]
        self.morph_control_panels = {}
        for category, title in morph_categories:
            if not has_groups(category):
                continue
            control_panel = MorphCategoryControlPanel(
                self.control_panel,
                title,
                category,
                param_groups)
            self.morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.non_morph_control_panels = {}
        non_morph_categories = [
            PoseParameterCategory.IRIS_ROTATION,
            PoseParameterCategory.FACE_ROTATION,
            PoseParameterCategory.BODY_ROTATION,
            PoseParameterCategory.BREATHING
        ]
        for category in non_morph_categories:
            if not has_groups(category):
                continue
            control_panel = SimpleParamGroupsControlPanel(
                self.control_panel,
                category,
                param_groups)
            self.non_morph_control_panels[category] = control_panel
            self.control_panel_sizer.Add(control_panel, 0, wx.EXPAND)

        self.control_panel_sizer.Fit(self.control_panel)
        self.main_sizer.Add(self.control_panel, 1, wx.FIXED_MINSIZE)

    def init_right_panel(self):
        """Create the result-image panel, output selector and "Save Image" button."""
        self.right_panel = wx.Panel(self, style=wx.SIMPLE_BORDER)
        right_panel_sizer = wx.BoxSizer(wx.VERTICAL)
        self.right_panel.SetSizer(right_panel_sizer)
        self.right_panel.SetAutoLayout(1)

        self.result_image_panel = wx.Panel(self.right_panel,
                                           size=(self.image_size, self.image_size),
                                           style=wx.SIMPLE_BORDER)
        self.result_image_panel.Bind(wx.EVT_PAINT, self.paint_result_image_panel)
        self.result_image_panel.Bind(wx.EVT_ERASE_BACKGROUND, self.on_erase_background)
        # One entry per output the poser can produce.
        self.output_index_choice = wx.Choice(
            self.right_panel,
            choices=[str(i) for i in range(self.poser.get_output_length())])
        self.output_index_choice.SetSelection(0)
        right_panel_sizer.Add(self.result_image_panel, 0, wx.FIXED_MINSIZE)
        right_panel_sizer.Add(self.output_index_choice, 0, wx.EXPAND)

        self.save_image_button = wx.Button(self.right_panel, wx.ID_ANY, "\nSave Image\n\n")
        right_panel_sizer.Add(self.save_image_button, 1, wx.EXPAND)
        self.save_image_button.Bind(wx.EVT_BUTTON, self.on_save_image)

        right_panel_sizer.Fit(self.right_panel)
        self.main_sizer.Add(self.right_panel, 0, wx.FIXED_MINSIZE)

    def create_param_category_choice(self, param_category: PoseParameterCategory):
        """Return a ``wx.Choice`` listing the group names of ``param_category``.

        The first entry is pre-selected when the list is not empty.
        """
        params = [param_group.get_group_name()
                  for param_group in self.poser.get_pose_parameter_groups()
                  if param_group.get_category() == param_category]
        choice = wx.Choice(self.control_panel, choices=params)
        if params:
            choice.SetSelection(0)
        return choice

    def load_image(self, event: wx.Event):
        """Ask for a PNG file and load it as the new source image.

        The image is resized to the poser's image size.  If it has no alpha
        channel (mode other than ``RGBA``) the current source image is
        cleared instead.  Any failure while loading is logged and reported
        in a message box.
        """
        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_OPEN)
        try:
            if file_dialog.ShowModal() != wx.ID_OK:
                return
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            try:
                pil_image = resize_PIL_image(extract_PIL_image_from_filelike(image_file_name),
                                             (self.poser.get_image_size(), self.poser.get_image_size()))
                w, h = pil_image.size
                if pil_image.mode != 'RGBA':
                    self.source_image_string = "Image must have alpha channel!"
                    self.wx_source_image = None
                    self.torch_source_image = None
                else:
                    self.wx_source_image = wx.Bitmap.FromBufferRGBA(w, h, pil_image.convert("RGBA").tobytes())
                    self.torch_source_image = extract_pytorch_image_from_PIL_image(pil_image) \
                        .to(self.device).to(self.dtype)
                self.source_image_dirty = True
                self.Refresh()
                self.Update()
            except Exception:
                # Was a bare `except:`; keep the user-facing dialog but do not
                # swallow KeyboardInterrupt/SystemExit, and record the cause.
                logging.exception("Could not load image %s", image_file_name)
                message_dialog = wx.MessageDialog(self, "Could not load image " + image_file_name, "Poser", wx.OK)
                message_dialog.ShowModal()
                message_dialog.Destroy()
        finally:
            file_dialog.Destroy()

    def paint_source_image_panel(self, event: wx.Event):
        """Paint handler: blit the source bitmap onto the source panel."""
        wx.BufferedPaintDC(self.source_image_panel, self.source_image_bitmap)

    def paint_result_image_panel(self, event: wx.Event):
        """Paint handler: blit the result bitmap onto the result panel."""
        wx.BufferedPaintDC(self.result_image_panel, self.result_image_bitmap)

    def draw_nothing_yet_string_to_bitmap(self, bitmap):
        """Clear ``bitmap`` and draw a centred "Nothing yet!" placeholder."""
        dc = wx.MemoryDC()
        dc.SelectObject(bitmap)

        dc.Clear()
        font = wx.Font(wx.FontInfo(14).Family(wx.FONTFAMILY_SWISS))
        dc.SetFont(font)
        w, h = dc.GetTextExtent("Nothing yet!")
        # Centre on both axes.  The vertical term used to be
        # `(self.image_size - - h) // 2`, which pushed the text below centre.
        dc.DrawText("Nothing yet!", (self.image_size - w) // 2, (self.image_size - h) // 2)

        del dc

    def get_current_pose(self):
        """Return the pose vector described by the current control settings.

        Starts from all zeros and lets every control panel write its own
        parameters into the list.
        """
        current_pose = [0.0] * self.poser.get_num_parameters()
        for morph_control_panel in self.morph_control_panels.values():
            morph_control_panel.set_param_value(current_pose)
        for rotation_control_panel in self.non_morph_control_panels.values():
            rotation_control_panel.set_param_value(current_pose)
        return current_pose

    def update_images(self, event: wx.Event):
        """Timer handler: redraw the source and result images when needed.

        Nothing is done if neither the source image, the pose nor the
        selected output index changed since the previous call.  Without a
        source image both bitmaps show the placeholder text.
        """
        current_pose = self.get_current_pose()
        output_index = self.output_index_choice.GetSelection()
        if not self.source_image_dirty \
                and self.last_pose is not None \
                and self.last_pose == current_pose \
                and self.last_output_index == output_index:
            return
        self.last_pose = current_pose
        self.last_output_index = output_index

        if self.torch_source_image is None:
            self.draw_nothing_yet_string_to_bitmap(self.source_image_bitmap)
            self.draw_nothing_yet_string_to_bitmap(self.result_image_bitmap)
            self.source_image_dirty = False
            self.Refresh()
            self.Update()
            return

        if self.source_image_dirty:
            dc = wx.MemoryDC()
            dc.SelectObject(self.source_image_bitmap)
            dc.Clear()
            dc.DrawBitmap(self.wx_source_image, 0, 0)
            # Release the bitmap from the memory DC before it is painted.
            del dc
            self.source_image_dirty = False

        pose = torch.tensor(current_pose, device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output_image = self.poser.pose(self.torch_source_image, pose, output_index)[0].detach().cpu()

        numpy_image = convert_output_image_from_torch_to_numpy(output_image)
        self.last_output_numpy_image = numpy_image
        # The array is indexed [row, column, channel] below, so shape[0] is the
        # height and shape[1] the width; wx expects (width, height).  The two
        # were previously swapped, which only worked because images are square.
        image_height, image_width = numpy_image.shape[0], numpy_image.shape[1]
        wx_image = wx.ImageFromBuffer(
            image_width,
            image_height,
            numpy_image[:, :, 0:3].tobytes(),
            numpy_image[:, :, 3].tobytes())
        wx_bitmap = wx_image.ConvertToBitmap()

        dc = wx.MemoryDC()
        dc.SelectObject(self.result_image_bitmap)
        dc.Clear()
        dc.DrawBitmap(wx_bitmap,
                      (self.image_size - image_width) // 2,
                      (self.image_size - image_height) // 2,
                      True)
        del dc

        self.Refresh()
        self.Update()

    def get_current_posedict(self):
        """Return the current pose as a ``{name: value}`` dictionary.

        The names are paired positionally with the values of
        :meth:`get_current_pose`.

        NOTE(review): ``'eye_unimpressed'`` appears twice in the key list, so
        the two corresponding pose values collapse into a single entry (the
        later value wins) and the dict has one entry fewer than the key list.
        The key list is left unchanged because it defines the layout of the
        JSON file written by :meth:`save_last_numpy_image`; confirm with the
        consumers of that file before renaming.
        """
        keys = [
            'eyebrow_troubled_left_index', 'eyebrow_troubled_right_index',
            'eyebrow_angry_left_index', 'eyebrow_angry_right_index',
            'eyebrow_lowered_left_index', 'eyebrow_lowered_right_index',
            'eyebrow_raised_left_index', 'eyebrow_raised_right_index',
            'eyebrow_happy_left_index', 'eyebrow_happy_right_index',
            'eyebrow_serious_left_index', 'eyebrow_serious_right_index',
            'eye_wink_left_index', 'eye_wink_right_index',
            'eye_happy_wink_left_index', 'eye_happy_wink_right_index',
            'eye_surprised_left_index', 'eye_surprised_right_index',
            'eye_relaxed_left_index', 'eye_relaxed_right_index',
            'eye_unimpressed', 'eye_unimpressed',
            'eye_raised_lower_eyelid_left_index', 'eye_raised_lower_eyelid_right_index',
            'iris_small_left_index', 'iris_small_right_index',
            'mouth_aaa_index', 'mouth_iii_index', 'mouth_uuu_index', 'mouth_eee_index', 'mouth_ooo_index',
            'mouth_delta',
            'mouth_lowered_corner_left_index', 'mouth_lowered_corner_right_index',
            'mouth_raised_corner_left_index', 'mouth_raised_corner_right_index',
            'mouth_smirk',
            'iris_rotation_x_index', 'iris_rotation_y_index',
            'head_x_index', 'head_y_index', 'neck_z_index',
            'body_y_index', 'body_z_index',
            'breathing_index',
        ]

        current_pose_values = self.get_current_pose()
        return dict(zip(keys, current_pose_values))

    def _confirm_overwrite(self, image_file_name) -> bool:
        """Ask whether an existing file may be replaced; True means yes."""
        message_dialog = wx.MessageDialog(self, f"Override {image_file_name}", "Manual Poser",
                                          wx.YES_NO | wx.ICON_QUESTION)
        try:
            return message_dialog.ShowModal() == wx.ID_YES
        finally:
            # The confirmation dialog used to be leaked.
            message_dialog.Destroy()

    def on_save_image(self, event: wx.Event):
        """Ask for a target file and save the last rendered image to it.

        Does nothing (apart from a log message) when no image has been
        rendered yet.  An existing file is only replaced after confirmation.
        Any failure while saving is logged and reported in a message box.
        """
        if self.last_output_numpy_image is None:
            logging.info("There is no output image to save!!!")
            return

        dir_name = "data/images"
        file_dialog = wx.FileDialog(self, "Choose an image", dir_name, "", "*.png", wx.FD_SAVE)
        try:
            if file_dialog.ShowModal() != wx.ID_OK:
                return
            image_file_name = os.path.join(file_dialog.GetDirectory(), file_dialog.GetFilename())
            try:
                if os.path.exists(image_file_name) and not self._confirm_overwrite(image_file_name):
                    return
                self.save_last_numpy_image(image_file_name)
            except Exception:
                # Was a bare `except:`; see load_image for the rationale.
                logging.exception("Could not save %s", image_file_name)
                message_dialog = wx.MessageDialog(self, f"Could not save {image_file_name}", "Manual Poser", wx.OK)
                message_dialog.ShowModal()
                message_dialog.Destroy()
        finally:
            file_dialog.Destroy()

    def save_last_numpy_image(self, image_file_name):
        """Write the last rendered image and its pose settings to disk.

        The image is saved to ``image_file_name``.  Next to it a JSON file
        with the same base name is written, containing
        ``{<file name without extension>: <pose dictionary>}``.

        :param image_file_name: path of the image file to write.
        """
        numpy_image = self.last_output_numpy_image
        pil_image = PIL.Image.fromarray(numpy_image, mode='RGBA')
        target_directory = os.path.dirname(image_file_name)
        if target_directory:
            # os.makedirs("") raises, so only create a directory if there is one.
            os.makedirs(target_directory, exist_ok=True)
        pil_image.save(image_file_name)

        path_without_extension = os.path.splitext(image_file_name)[0]
        json_file_path = path_without_extension + ".json"
        filename_without_extension = os.path.basename(path_without_extension)
        data_dict_with_filename = {filename_without_extension: self.get_current_posedict()}

        with open(json_file_path, "w", encoding="utf-8") as file:
            json.dump(data_dict_with_filename, file, indent=4)
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command line entry point: parse the options, load the poser model,
    # open the main window and run the wx event loop.
    parser = argparse.ArgumentParser(description='Manually pose a character image.')
    parser.add_argument(
        '--model',
        type=str,
        required=False,
        default='separable_float',
        choices=['standard_float', 'separable_float', 'standard_half', 'separable_half'],
        help='The model to use.')
    # The device used to be hard-coded to "cuda"; that is still the default,
    # but it can now be overridden (e.g. "--device cpu").
    parser.add_argument(
        '--device',
        type=str,
        required=False,
        default='cuda',
        help='The torch device to run the model on, e.g. "cuda" or "cpu".')
    args = parser.parse_args()

    try:
        # torch.device() raises RuntimeError for an invalid device string,
        # so it is constructed inside the same error handler as the model load.
        device = torch.device(args.device)
        poser = load_poser(args.model, device)
    except RuntimeError as e:
        # Report the failure on stderr and signal it through the exit status
        # (previously the script exited with status 0 here).
        print(e, file=sys.stderr)
        sys.exit(1)

    app = wx.App()
    main_frame = MainFrame(poser, device)
    main_frame.Show(True)
    # Poll the controls roughly every 30 ms; see MainFrame.update_images.
    main_frame.timer.Start(30)
    app.MainLoop()
|