API cleanliness

This commit is contained in:
Juha Jeronen
2024-01-08 14:46:09 +02:00
parent fc58895a62
commit e94521251f
2 changed files with 71 additions and 31 deletions

View File

@@ -628,6 +628,7 @@ def api_caption():
@app.route("/api/summarize", methods=["POST"])
@require_module("summarize")
def api_summarize():
"""Summarize the text posted in the request. Return the summary."""
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
@@ -643,6 +644,10 @@ def api_summarize():
@app.route("/api/classify", methods=["POST"])
@require_module("classify")
def api_classify():
"""Perform sentiment analysis (classification) on the text posted in the request. Return the result.
Also, if `talkinghead` is enabled, automatically update its emotion based on the classification result.
"""
data = request.get_json()
if "text" not in data or not isinstance(data["text"], str):
@@ -652,14 +657,18 @@ def api_classify():
classification = classify_text(data["text"])
print("Classification output:", classification, sep="\n")
gc.collect()
if "talkinghead" in modules: #send emotion to talkinghead
talkinghead.setEmotion(classification)
# TODO: Feature orthogonality: would be better if the client called the `set_emotion` endpoint explicitly
# also when it uses `classify`, if it intends to update the talkinghead state.
if "talkinghead" in modules: # send emotion to talkinghead
print("Updating talkinghead emotion from classification results")
talkinghead.set_emotion_from_classification(classification)
return jsonify({"classification": classification})
@app.route("/api/classify/labels", methods=["GET"])
@require_module("classify")
def api_classify_labels():
"""Return the available classifier labels for text sentiment (character emotion)."""
classification = classify_text("")
labels = [x["label"] for x in classification]
if "talkinghead" in modules:
@@ -668,38 +677,48 @@ def api_classify_labels():
@app.route("/api/talkinghead/load", methods=["POST"])
@require_module("talkinghead")
def live_load():
def api_talkinghead_load():
"""Load the talkinghead sprite posted in the request. Resume animation if paused."""
file = request.files['file']
# convert stream to bytes and pass to talkinghead_load
# convert stream to bytes and pass to talkinghead
return talkinghead.talkinghead_load_file(file.stream)
@app.route('/api/talkinghead/unload')
@require_module("talkinghead")
def live_unload():
def api_talkinghead_unload():
"""Pause talkinghead animation. Can be enabled again via '/api/talkinghead/load'."""
return talkinghead.unload()
@app.route('/api/talkinghead/start_talking')
@require_module("talkinghead")
def start_talking():
def api_talkinghead_start_talking():
"""Start the mouth animation for talking."""
return talkinghead.start_talking()
@app.route('/api/talkinghead/stop_talking')
@require_module("talkinghead")
def stop_talking():
def api_talkinghead_stop_talking():
"""Stop the mouth animation for talking."""
return talkinghead.stop_talking()
@app.route('/api/talkinghead/set_emotion', methods=["POST"])
@require_module("talkinghead")
def emote():
def api_talkinghead_set_emotion():
"""Set talkinghead character emotion to that posted in the request.
There is no getter, because SillyTavern keeps its state in the frontend
and the plugins only act as slaves (in the technological sense of the word).
"""
data = request.get_json()
if "emotion_name" not in data or not isinstance(data["emotion_name"], str):
abort(400, '"emotion_name" is required')
emotion_name = data["emotion_name"]
return talkinghead.setEmotion([{"label": emotion_name, "score": 1.0}]) # mimic the `classify` API result
return talkinghead.set_emotion(emotion_name)
@app.route('/api/talkinghead/result_feed')
@require_module("talkinghead")
def result_feed():
def api_talkinghead_result_feed():
"""Live character output. Stream of video frames, each as a PNG encoded image."""
return talkinghead.result_feed()
@app.route("/api/image", methods=["POST"])
@@ -1139,5 +1158,5 @@ if args.share:
print(f"{Fore.GREEN}{Style.NORMAL}Running on: {cloudflare}{Style.RESET_ALL}")
ignore_auth.append(tts_play_sample)
ignore_auth.append(result_feed)
ignore_auth.append(api_talkinghead_result_feed)
app.run(host=host, port=port)