API cleanliness

This commit is contained in:
Juha Jeronen
2024-01-08 14:46:09 +02:00
parent fc58895a62
commit e94521251f
2 changed files with 71 additions and 31 deletions

View File

@@ -628,6 +628,7 @@ def api_caption():
@app.route("/api/summarize", methods=["POST"])
@require_module("summarize")
def api_summarize():
"""Summarize the text posted in the request. Return the summary."""
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
@@ -643,6 +644,10 @@ def api_summarize():
@app.route("/api/classify", methods=["POST"])
@require_module("classify")
def api_classify():
"""Perform sentiment analysis (classification) on the text posted in the request. Return the result.
Also, if `talkinghead` is enabled, automatically update its emotion based on the classification result.
"""
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
@@ -652,14 +657,18 @@ def api_classify():
classification = classify_text(data["text"])
print("Classification output:", classification, sep="\n")
gc.collect()
if "talkinghead" in modules: #send emotion to talkinghead
talkinghead.setEmotion(classification)
# TODO: Feature orthogonality: would be better if the client called the `set_emotion` endpoint explicitly
# also when it uses `classify`, if it intends to update the talkinghead state.
if "talkinghead" in modules: # send emotion to talkinghead
print("Updating talkinghead emotion from classification results")
talkinghead.set_emotion_from_classification(classification)
return jsonify({"classification": classification})
@app.route("/api/classify/labels", methods=["GET"])
@require_module("classify")
def api_classify_labels():
"""Return the available classifier labels for text sentiment (character emotion)."""
classification = classify_text("")
labels = [x["label"] for x in classification]
if "talkinghead" in modules:
@@ -668,38 +677,48 @@ def api_classify_labels():
@app.route("/api/talkinghead/load", methods=["POST"])
@require_module("talkinghead")
def live_load():
def api_talkinghead_load():
"""Load the talkinghead sprite posted in the request. Resume animation if paused."""
file = request.files['file']
# convert stream to bytes and pass to talkinghead_load
# convert stream to bytes and pass to talkinghead
return talkinghead.talkinghead_load_file(file.stream)
@app.route('/api/talkinghead/unload')
@require_module("talkinghead")
def live_unload():
def api_talkinghead_unload():
"""Pause talkinghead animation. Can be enabled again via '/api/talkinghead/load'."""
return talkinghead.unload()
@app.route('/api/talkinghead/start_talking')
@require_module("talkinghead")
def start_talking():
def api_talkinghead_start_talking():
"""Start the mouth animation for talking."""
return talkinghead.start_talking()
@app.route('/api/talkinghead/stop_talking')
@require_module("talkinghead")
def stop_talking():
def api_talkinghead_stop_talking():
"""Stop the mouth animation for talking."""
return talkinghead.stop_talking()
@app.route('/api/talkinghead/set_emotion', methods=["POST"])
@require_module("talkinghead")
def emote():
def api_talkinghead_set_emotion():
"""Set talkinghead character emotion to that posted in the request.
There is no getter, because SillyTavern keeps its state in the frontend
and the plugins only act as slaves (in the technological sense of the word).
"""
data = request.get_json()
if "emotion_name" not in data or not isinstance(data["emotion_name"], str):
abort(400, '"emotion_name" is required')
emotion_name = data["emotion_name"]
return talkinghead.setEmotion([{"label": emotion_name, "score": 1.0}]) # mimic the `classify` API result
return talkinghead.set_emotion(emotion_name)
@app.route('/api/talkinghead/result_feed')
@require_module("talkinghead")
def result_feed():
def api_talkinghead_result_feed():
"""Live character output. Stream of video frames, each as a PNG encoded image."""
return talkinghead.result_feed()
@app.route("/api/image", methods=["POST"])
@@ -1139,5 +1158,5 @@ if args.share:
print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
ignore_auth.append(tts_play_sample)
ignore_auth.append(result_feed)
ignore_auth.append(api_talkinghead_result_feed)
app.run(host=host, port=port)

View File

@@ -6,6 +6,13 @@ This module implements the live animation backend and serves the API. For usage,
If you want to play around with THA3 expressions in a standalone app, see `manual_poser.py`.
"""
__all__ = ["set_emotion_from_classification", "set_emotion",
"unload",
"start_talking", "stop_talking",
"result_feed",
"talkinghead_load_file",
"launch"]
import atexit
import io
import logging
@@ -61,54 +68,68 @@ global_reload_image = None
app = Flask(__name__)
CORS(app)
def setEmotion(_emotion: Dict[str, float]) -> None:
def set_emotion_from_classification(emotion_scores: List[Dict[str, Union[str, float]]]) -> str:
"""Set the current emotion of the character based on sentiment analysis results.
Currently, we pick the emotion with the highest confidence score.
The `set_emotion` API endpoint also uses this function to set the current emotion,
with a manually formatted dictionary containing just one entry.
`emotion_scores`: results from classify module: [{"label": emotion0, "score": confidence0}, ...]
_emotion: result of sentiment analysis: {emotion0: confidence0, ...}
Return a status message for passing over HTTP.
"""
global current_emotion
highest_score = float("-inf")
highest_label = None
for item in _emotion:
for item in emotion_scores:
if item["score"] > highest_score:
highest_score = item["score"]
highest_label = item["label"]
logger.info(f"set_emotion_from_classification: winning score: {highest_label} = {highest_score}")
return set_emotion(highest_label)
if highest_label not in global_animator_instance.emotions:
logger.warning(f"setEmotion: emotion '{highest_label}' does not exist, setting to 'neutral'")
highest_label = "neutral"
def set_emotion(emotion: str) -> str:
"""Set the current emotion of the character.
logger.info(f"setEmotion: applying emotion {highest_label}")
current_emotion = highest_label
return f"emotion set to {highest_label}"
Return a status message for passing over HTTP.
"""
global current_emotion
if emotion not in global_animator_instance.emotions:
logger.warning(f"set_emotion: specified emotion '{emotion}' does not exist, selecting 'neutral'")
emotion = "neutral"
logger.info(f"set_emotion: applying emotion {emotion}")
current_emotion = emotion
return f"emotion set to {emotion}"
def unload() -> str:
"""Stop animation."""
"""Stop animation.
Return a status message for passing over HTTP.
"""
global animation_running
animation_running = False
logger.info("unload: animation paused")
return "Animation Paused"
return "animation paused"
def start_talking() -> str:
"""Start talking animation."""
"""Start talking animation.
Return a status message for passing over HTTP.
"""
global is_talking
is_talking = True
logger.debug("start_talking called")
return "started"
return "talking started"
def stop_talking() -> str:
"""Stop talking animation."""
"""Stop talking animation.
Return a status message for passing over HTTP.
"""
global is_talking
is_talking = False
logger.debug("stop_talking called")
return "stopped"
return "talking stopped"
# There are three tasks we must do each frame:
#