mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-04-29 02:41:21 +00:00
API cleanliness
This commit is contained in:
41
server.py
41
server.py
@@ -628,6 +628,7 @@ def api_caption():
|
||||
@app.route("/api/summarize", methods=["POST"])
|
||||
@require_module("summarize")
|
||||
def api_summarize():
|
||||
"""Summarize the text posted in the request. Return the summary."""
|
||||
data = request.get_json()
|
||||
|
||||
if "text" not in data or not isinstance(data["text"], str):
|
||||
@@ -643,6 +644,10 @@ def api_summarize():
|
||||
@app.route("/api/classify", methods=["POST"])
|
||||
@require_module("classify")
|
||||
def api_classify():
|
||||
"""Perform sentiment analysis (classification) on the text posted in the request. Return the result.
|
||||
|
||||
Also, if `talkinghead` is enabled, automatically update its emotion based on the classification result.
|
||||
"""
|
||||
data = request.get_json()
|
||||
|
||||
if "text" not in data or not isinstance(data["text"], str):
|
||||
@@ -652,14 +657,18 @@ def api_classify():
|
||||
classification = classify_text(data["text"])
|
||||
print("Classification output:", classification, sep="\n")
|
||||
gc.collect()
|
||||
if "talkinghead" in modules: #send emotion to talkinghead
|
||||
talkinghead.setEmotion(classification)
|
||||
# TODO: Feature orthogonality: would be better if the client called the `set_emotion` endpoint explicitly
|
||||
# also when it uses `classify`, if it intends to update the talkinghead state.
|
||||
if "talkinghead" in modules: # send emotion to talkinghead
|
||||
print("Updating talkinghead emotion from classification results")
|
||||
talkinghead.set_emotion_from_classification(classification)
|
||||
return jsonify({"classification": classification})
|
||||
|
||||
|
||||
@app.route("/api/classify/labels", methods=["GET"])
|
||||
@require_module("classify")
|
||||
def api_classify_labels():
|
||||
"""Return the available classifier labels for text sentiment (character emotion)."""
|
||||
classification = classify_text("")
|
||||
labels = [x["label"] for x in classification]
|
||||
if "talkinghead" in modules:
|
||||
@@ -668,38 +677,48 @@ def api_classify_labels():
|
||||
|
||||
@app.route("/api/talkinghead/load", methods=["POST"])
|
||||
@require_module("talkinghead")
|
||||
def live_load():
|
||||
def api_talkinghead_load():
|
||||
"""Load the talkinghead sprite posted in the request. Resume animation if paused."""
|
||||
file = request.files['file']
|
||||
# convert stream to bytes and pass to talkinghead_load
|
||||
# convert stream to bytes and pass to talkinghead
|
||||
return talkinghead.talkinghead_load_file(file.stream)
|
||||
|
||||
@app.route('/api/talkinghead/unload')
|
||||
@require_module("talkinghead")
|
||||
def live_unload():
|
||||
def api_talkinghead_unload():
|
||||
"""Pause talkinghead animation. Can be enabled again via '/api/talkinghead/load'."""
|
||||
return talkinghead.unload()
|
||||
|
||||
@app.route('/api/talkinghead/start_talking')
|
||||
@require_module("talkinghead")
|
||||
def start_talking():
|
||||
def api_talkinghead_start_talking():
|
||||
"""Start the mouth animation for talking."""
|
||||
return talkinghead.start_talking()
|
||||
|
||||
@app.route('/api/talkinghead/stop_talking')
|
||||
@require_module("talkinghead")
|
||||
def stop_talking():
|
||||
def api_talkinghead_stop_talking():
|
||||
"""Stop the mouth animation for talking."""
|
||||
return talkinghead.stop_talking()
|
||||
|
||||
@app.route('/api/talkinghead/set_emotion', methods=["POST"])
|
||||
@require_module("talkinghead")
|
||||
def emote():
|
||||
def api_talkinghead_set_emotion():
|
||||
"""Set talkinghead character emotion to that posted in the request.
|
||||
|
||||
There is no getter, because SillyTavern keeps its state in the frontend
|
||||
and the plugins only act as slaves (in the technological sense of the word).
|
||||
"""
|
||||
data = request.get_json()
|
||||
if "emotion_name" not in data or not isinstance(data["emotion_name"], str):
|
||||
abort(400, '"emotion_name" is required')
|
||||
emotion_name = data["emotion_name"]
|
||||
return talkinghead.setEmotion([{"label": emotion_name, "score": 1.0}]) # mimic the `classify` API result
|
||||
return talkinghead.set_emotion(emotion_name)
|
||||
|
||||
@app.route('/api/talkinghead/result_feed')
|
||||
@require_module("talkinghead")
|
||||
def result_feed():
|
||||
def api_talkinghead_result_feed():
|
||||
"""Live character output. Stream of video frames, each as a PNG encoded image."""
|
||||
return talkinghead.result_feed()
|
||||
|
||||
@app.route("/api/image", methods=["POST"])
|
||||
@@ -1139,5 +1158,5 @@ if args.share:
|
||||
print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
|
||||
|
||||
ignore_auth.append(tts_play_sample)
|
||||
ignore_auth.append(result_feed)
|
||||
ignore_auth.append(api_talkinghead_result_feed)
|
||||
app.run(host=host, port=port)
|
||||
|
||||
@@ -6,6 +6,13 @@ This module implements the live animation backend and serves the API. For usage,
|
||||
If you want to play around with THA3 expressions in a standalone app, see `manual_poser.py`.
|
||||
"""
|
||||
|
||||
__all__ = ["set_emotion_from_classification", "set_emotion",
|
||||
"unload",
|
||||
"start_talking", "stop_talking",
|
||||
"result_feed",
|
||||
"talkinghead_load_file",
|
||||
"launch"]
|
||||
|
||||
import atexit
|
||||
import io
|
||||
import logging
|
||||
@@ -61,54 +68,68 @@ global_reload_image = None
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
def setEmotion(_emotion: Dict[str, float]) -> None:
|
||||
def set_emotion_from_classification(emotion_scores: List[Dict[str, Union[str, float]]]) -> str:
|
||||
"""Set the current emotion of the character based on sentiment analysis results.
|
||||
|
||||
Currently, we pick the emotion with the highest confidence score.
|
||||
|
||||
The `set_emotion` API endpoint also uses this function to set the current emotion,
|
||||
with a manually formatted dictionary containing just one entry.
|
||||
`emotion_scores`: results from classify module: [{"label": emotion0, "score": confidence0}, ...]
|
||||
|
||||
_emotion: result of sentiment analysis: {emotion0: confidence0, ...}
|
||||
Return a status message for passing over HTTP.
|
||||
"""
|
||||
global current_emotion
|
||||
|
||||
highest_score = float("-inf")
|
||||
highest_label = None
|
||||
|
||||
for item in _emotion:
|
||||
for item in emotion_scores:
|
||||
if item["score"] > highest_score:
|
||||
highest_score = item["score"]
|
||||
highest_label = item["label"]
|
||||
logger.info(f"set_emotion_from_classification: winning score: {highest_label} = {highest_score}")
|
||||
return set_emotion(highest_label)
|
||||
|
||||
if highest_label not in global_animator_instance.emotions:
|
||||
logger.warning(f"setEmotion: emotion '{highest_label}' does not exist, setting to 'neutral'")
|
||||
highest_label = "neutral"
|
||||
def set_emotion(emotion: str) -> str:
|
||||
"""Set the current emotion of the character.
|
||||
|
||||
logger.info(f"setEmotion: applying emotion {highest_label}")
|
||||
current_emotion = highest_label
|
||||
return f"emotion set to {highest_label}"
|
||||
Return a status message for passing over HTTP.
|
||||
"""
|
||||
global current_emotion
|
||||
|
||||
if emotion not in global_animator_instance.emotions:
|
||||
logger.warning(f"set_emotion: specified emotion '{emotion}' does not exist, selecting 'neutral'")
|
||||
emotion = "neutral"
|
||||
|
||||
logger.info(f"set_emotion: applying emotion {emotion}")
|
||||
current_emotion = emotion
|
||||
return f"emotion set to {emotion}"
|
||||
|
||||
def unload() -> str:
|
||||
"""Stop animation."""
|
||||
"""Stop animation.
|
||||
|
||||
Return a status message for passing over HTTP.
|
||||
"""
|
||||
global animation_running
|
||||
animation_running = False
|
||||
logger.info("unload: animation paused")
|
||||
return "Animation Paused"
|
||||
return "animation paused"
|
||||
|
||||
def start_talking() -> str:
|
||||
"""Start talking animation."""
|
||||
"""Start talking animation.
|
||||
|
||||
Return a status message for passing over HTTP.
|
||||
"""
|
||||
global is_talking
|
||||
is_talking = True
|
||||
logger.debug("start_talking called")
|
||||
return "started"
|
||||
return "talking started"
|
||||
|
||||
def stop_talking() -> str:
|
||||
"""Stop talking animation."""
|
||||
"""Stop talking animation.
|
||||
|
||||
Return a status message for passing over HTTP.
|
||||
"""
|
||||
global is_talking
|
||||
is_talking = False
|
||||
logger.debug("stop_talking called")
|
||||
return "stopped"
|
||||
return "talking stopped"
|
||||
|
||||
# There are three tasks we must do each frame:
|
||||
#
|
||||
|
||||
Reference in New Issue
Block a user