mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-27 09:52:03 +00:00
159 lines
4.5 KiB
Python
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from typing import Tuple, List, Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
|
|
class PoseParameterCategory(Enum):
    """Category tag attached to each PoseParameterGroup.

    Values are explicit, consecutive integers starting at 1, declared in the
    order below.
    """

    # Eye region.
    EYEBROW = 1
    EYE = 2
    IRIS_MORPH = 3
    IRIS_ROTATION = 4

    # Mouth.
    MOUTH = 5

    # Whole-head / whole-body motion.
    FACE_ROTATION = 6
    BODY_ROTATION = 7
    BREATHING = 8
|
|
|
|
|
|
class PoseParameterGroup:
    """A named group of one or two related pose parameters.

    An arity-1 group contributes a single parameter named ``group_name``.
    An arity-2 group contributes two parameters, ``group_name + "_left"``
    and ``group_name + "_right"``, in that order.

    Args:
        group_name: Base name of the group.
        parameter_index: Index of the group's first parameter among all
            parameters. Stored as given and returned by
            get_parameter_index().
        category: Category tag for the group. Stored and returned by
            get_category().
        arity: Number of parameters in the group; must be 1 or 2.
        discrete: Flag returned by is_discrete(); not interpreted here.
        default_value: Value returned by get_default_value(); not checked
            against ``range``.
        range: Value range as a 2-tuple, presumably (low, high). Defaults
            to (0.0, 1.0) when None. Not validated here.

    Raises:
        ValueError: if ``arity`` is not 1 or 2.
    """

    def __init__(self,
                 group_name: str,
                 parameter_index: int,
                 category: PoseParameterCategory,
                 arity: int = 1,
                 discrete: bool = False,
                 default_value: float = 0.0,
                 range: Optional[Tuple[float, float]] = None):
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`. There, e.g. arity=3 would fall into the two-name branch
        # below while self.arity stayed 3, desynchronising names and counts.
        if arity not in (1, 2):
            raise ValueError(f"arity must be 1 or 2, got {arity!r}")

        # `range` shadows the builtin, but the name is part of the public
        # keyword interface, so it is kept.
        if range is None:
            range = (0.0, 1.0)

        if arity == 1:
            parameter_names = [group_name]
        else:
            parameter_names = [group_name + "_left", group_name + "_right"]

        self.parameter_names = parameter_names
        self.range = range
        self.default_value = default_value
        self.discrete = discrete
        self.arity = arity
        self.category = category
        self.parameter_index = parameter_index
        self.group_name = group_name

    def get_arity(self) -> int:
        """Return the number of parameters in this group (1 or 2)."""
        return self.arity

    def get_group_name(self) -> str:
        """Return the base name the group was created with."""
        return self.group_name

    def get_parameter_names(self) -> List[str]:
        """Return the group's parameter names, one per unit of arity.

        The internal list is returned, not a copy.
        """
        return self.parameter_names

    def is_discrete(self) -> bool:
        """Return the ``discrete`` flag given at construction."""
        return self.discrete

    def get_range(self) -> Tuple[float, float]:
        """Return the value range; (0.0, 1.0) unless one was given."""
        return self.range

    def get_default_value(self) -> float:
        """Return the default value given at construction."""
        return self.default_value

    def get_parameter_index(self) -> int:
        """Return the index of this group's first parameter."""
        return self.parameter_index

    def get_category(self) -> PoseParameterCategory:
        """Return the category tag given at construction."""
        return self.category
|
|
|
|
|
|
class PoseParameters:
    """Ordered collection of PoseParameterGroup objects.

    Individual parameters are addressed by a flat index. Groups are laid out
    in list order, and each group occupies ``arity`` consecutive slots.
    """

    def __init__(self, pose_parameter_groups: List[PoseParameterGroup]):
        # NOTE(review): the list is stored as given, not copied.
        self.pose_parameter_groups = pose_parameter_groups

    def get_parameter_index(self, name: str) -> int:
        """Return the flat index of the parameter called ``name``.

        Raises:
            RuntimeError: if no group defines a parameter with that name.
        """
        all_names = (param_name
                     for group in self.pose_parameter_groups
                     for param_name in group.parameter_names)
        for index, param_name in enumerate(all_names):
            if param_name == name:
                return index
        raise RuntimeError("Cannot find parameter with name %s" % name)

    def get_parameter_name(self, index: int) -> str:
        """Return the name of the parameter at flat position ``index``.

        Raises:
            IndexError: if ``index`` is outside [0, get_parameter_count()).
        """
        # Explicit check instead of `assert` (stripped under `python -O`).
        # Without it, a negative index would silently return a name taken
        # from the end of the first group.
        count = self.get_parameter_count()
        if not 0 <= index < count:
            raise IndexError(
                f"parameter index {index} out of range [0, {count})")

        for group in self.pose_parameter_groups:
            if index < group.arity:
                return group.parameter_names[index]
            index -= group.arity

        # Unreachable while every group's parameter_names has `arity` entries.
        raise RuntimeError("Something is wrong here!!!")

    def get_pose_parameter_groups(self):
        """Return the underlying list of groups (not a copy)."""
        return self.pose_parameter_groups

    def get_parameter_count(self):
        """Return the total number of parameters, i.e. the sum of arities."""
        return sum(group.arity for group in self.pose_parameter_groups)

    class Builder:
        """Fluent helper that assembles a PoseParameters.

        Each group's ``parameter_index`` is assigned automatically as the
        running sum of the arities of the groups added before it.
        """

        def __init__(self):
            # Flat index the next added group's first parameter will receive.
            self.index = 0
            self.pose_parameter_groups = []

        def add_parameter_group(self,
                                group_name: str,
                                category: PoseParameterCategory,
                                arity: int = 1,
                                discrete: bool = False,
                                default_value: float = 0.0,
                                range: Optional[Tuple[float, float]] = None
                                ) -> 'PoseParameters.Builder':
            """Append a new group and return ``self`` for chaining.

            Arguments are forwarded to PoseParameterGroup. Any exception
            from that constructor propagates before ``self.index`` is
            advanced.
            """
            self.pose_parameter_groups.append(
                PoseParameterGroup(
                    group_name,
                    self.index,
                    category,
                    arity,
                    discrete,
                    default_value,
                    range))
            self.index += arity
            return self

        def build(self) -> 'PoseParameters':
            """Return a PoseParameters holding the groups added so far."""
            # Pass a copy: previously the builder's own list was shared, so
            # further add_parameter_group() calls on this builder mutated
            # PoseParameters objects it had already built.
            return PoseParameters(list(self.pose_parameter_groups))
|
|
|
|
|
|
class Poser(ABC):
    """Abstract interface for a model that poses an input image.

    Subclasses must implement every method marked @abstractmethod.
    get_dtype() has a default implementation that subclasses may override.
    """

    @abstractmethod
    def get_image_size(self) -> int:
        """Return the image size this poser works with.

        NOTE(review): presumably the side length in pixels of a square
        image; confirm against concrete implementations.
        """

    @abstractmethod
    def get_output_length(self) -> int:
        """Return the number of outputs this poser produces.

        NOTE(review): presumably equals len(get_posing_outputs(...)); confirm.
        """

    @abstractmethod
    def get_pose_parameter_groups(self) -> List[PoseParameterGroup]:
        """Return the pose parameter groups this poser accepts."""

    @abstractmethod
    def get_num_parameters(self) -> int:
        """Return the number of individual pose parameters."""

    @abstractmethod
    def pose(self, image: Tensor, pose: Tensor, output_index: int = 0) -> Tensor:
        """Pose ``image`` according to ``pose`` and return one output tensor.

        NOTE(review): ``output_index`` presumably selects one entry of
        get_posing_outputs(); tensor shapes are implementation-defined.
        """

    @abstractmethod
    def get_posing_outputs(self, image: Tensor, pose: Tensor) -> List[Tensor]:
        """Pose ``image`` according to ``pose`` and return all outputs."""

    def get_dtype(self) -> torch.dtype:
        """Return the torch dtype used by this poser; torch.float by default."""
        return torch.float
|