Mirror of https://github.com/SillyTavern/SillyTavern-Extras.git (synced 2026-01-26 17:20:04 +00:00)
Use pipeline for text summarization
This commit is contained in:
Changed file: server.py (48 lines changed)
@@ -207,10 +207,7 @@ if "caption" in modules:
|
||||
|
||||
if "summarize" in modules:
|
||||
print("Initializing a text summarization model...")
|
||||
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
|
||||
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
|
||||
summarization_model, torch_dtype=torch_dtype
|
||||
).to(device)
|
||||
summarization_pipeline = pipeline('summarization', model=summarization_model, device=device_string, torch_dtype=torch_dtype)
|
||||
|
||||
if "sd" in modules and not sd_use_remote:
|
||||
from diffusers import StableDiffusionPipeline
|
||||
@@ -437,44 +434,20 @@ def caption_image(raw_image: Image) -> str:
|
||||
return caption
|
||||
|
||||
|
||||
def summarize_chunks(text: str, params: dict) -> str:
|
||||
def summarize_chunks(text: str) -> str:
|
||||
try:
|
||||
return summarize(text, params)
|
||||
return summarize(text)
|
||||
except IndexError:
|
||||
print(
|
||||
"Sequence length too large for model, cutting text in half and calling again"
|
||||
)
|
||||
new_params = params.copy()
|
||||
new_params["max_length"] = new_params["max_length"] // 2
|
||||
new_params["min_length"] = new_params["min_length"] // 2
|
||||
return summarize_chunks(
|
||||
text[: (len(text) // 2)], new_params
|
||||
) + summarize_chunks(text[(len(text) // 2) :], new_params)
|
||||
text[: (len(text) // 2)]
|
||||
) + summarize_chunks(text[(len(text) // 2) :])
|
||||
|
||||
|
||||
def summarize(text: str, params: dict) -> str:
|
||||
# Tokenize input
|
||||
inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
|
||||
token_count = len(inputs[0])
|
||||
|
||||
bad_words_ids = [
|
||||
summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
|
||||
for bad_word in params["bad_words"]
|
||||
]
|
||||
summary_ids = summarization_transformer.generate(
|
||||
inputs["input_ids"],
|
||||
num_beams=2,
|
||||
max_new_tokens=max(token_count, int(params["max_length"])),
|
||||
min_new_tokens=min(token_count, int(params["min_length"])),
|
||||
repetition_penalty=float(params["repetition_penalty"]),
|
||||
temperature=float(params["temperature"]),
|
||||
length_penalty=float(params["length_penalty"]),
|
||||
bad_words_ids=bad_words_ids,
|
||||
)
|
||||
summary = summarization_tokenizer.batch_decode(
|
||||
summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
|
||||
)[0]
|
||||
summary = normalize_string(summary)
|
||||
def summarize(text: str) -> str:
|
||||
summary = normalize_string(summarization_pipeline(text)[0]['summary_text'])
|
||||
return summary
|
||||
|
||||
|
||||
@@ -627,13 +600,8 @@ def api_summarize():
|
||||
if "text" not in data or not isinstance(data["text"], str):
|
||||
abort(400, '"text" is required')
|
||||
|
||||
params = DEFAULT_SUMMARIZE_PARAMS.copy()
|
||||
|
||||
if "params" in data and isinstance(data["params"], dict):
|
||||
params.update(data["params"])
|
||||
|
||||
print("Summary input:", data["text"], sep="\n")
|
||||
summary = summarize_chunks(data["text"], params)
|
||||
summary = summarize_chunks(data["text"])
|
||||
print("Summary output:", summary, sep="\n")
|
||||
gc.collect()
|
||||
return jsonify({"summary": summary})
|
||||
|
||||
Reference in New Issue
Block a user