Use pipeline for text summarization

Cohee
2023-12-20 01:57:38 +02:00
parent 423f51e3f8
commit 47a5489142


@@ -207,10 +207,7 @@ if "caption" in modules:
if "summarize" in modules:
print("Initializing a text summarization model...")
summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
summarization_model, torch_dtype=torch_dtype
).to(device)
summarization_pipeline = pipeline('summarization', model=summarization_model, device=device_string, torch_dtype=torch_dtype)
if "sd" in modules and not sd_use_remote:
from diffusers import StableDiffusionPipeline
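
For context on the new call: Hugging Face's pipeline('summarization', ...) factory bundles the tokenizer load, model load, generation, and decoding that the removed lines performed by hand. A minimal sketch of the same initialization pattern, assuming a BART-style checkpoint and CPU inference (the model name, device index, and length limits below are illustrative placeholders, not the values this repository configures):

from transformers import pipeline

# Illustrative checkpoint; the server reads its model name from configuration.
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=-1,  # -1 = CPU; a CUDA device index such as 0 selects a GPU
)

# The pipeline returns a list of dicts; 'summary_text' holds the decoded summary.
text = "A long chat transcript to condense into a few sentences..."
print(summarizer(text, max_length=60, min_length=10)[0]["summary_text"])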
@@ -437,44 +434,20 @@ def caption_image(raw_image: Image) -> str:
     return caption
 
-def summarize_chunks(text: str, params: dict) -> str:
+def summarize_chunks(text: str) -> str:
     try:
-        return summarize(text, params)
+        return summarize(text)
     except IndexError:
         print(
             "Sequence length too large for model, cutting text in half and calling again"
         )
-        new_params = params.copy()
-        new_params["max_length"] = new_params["max_length"] // 2
-        new_params["min_length"] = new_params["min_length"] // 2
         return summarize_chunks(
-            text[: (len(text) // 2)], new_params
-        ) + summarize_chunks(text[(len(text) // 2) :], new_params)
+            text[: (len(text) // 2)]
+        ) + summarize_chunks(text[(len(text) // 2) :])
 
-def summarize(text: str, params: dict) -> str:
-    # Tokenize input
-    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
-    token_count = len(inputs[0])
-    bad_words_ids = [
-        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
-        for bad_word in params["bad_words"]
-    ]
-    summary_ids = summarization_transformer.generate(
-        inputs["input_ids"],
-        num_beams=2,
-        max_new_tokens=max(token_count, int(params["max_length"])),
-        min_new_tokens=min(token_count, int(params["min_length"])),
-        repetition_penalty=float(params["repetition_penalty"]),
-        temperature=float(params["temperature"]),
-        length_penalty=float(params["length_penalty"]),
-        bad_words_ids=bad_words_ids,
-    )
-    summary = summarization_tokenizer.batch_decode(
-        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-    )[0]
-    summary = normalize_string(summary)
+def summarize(text: str) -> str:
+    summary = normalize_string(summarization_pipeline(text)[0]['summary_text'])
     return summary
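
The except IndexError fallback above survives the rewrite: when the input is too long for the model, the text is halved, each half is summarized separately, and the partial summaries are concatenated. A self-contained sketch of that divide-and-recurse idea, with a character-length guard standing in for the exception (the 3000-character limit and the summarize_one helper are illustrative only, not part of the commit):

def summarize_one(text: str) -> str:
    # Stand-in for the real call: summarization_pipeline(text)[0]['summary_text']
    return text[:100]

def summarize_recursive(text: str, limit: int = 3000) -> str:
    # Halve the input until each piece fits, then join the partial summaries,
    # mirroring summarize_chunks() above.
    if len(text) <= limit:
        return summarize_one(text)
    mid = len(text) // 2
    return summarize_recursive(text[:mid], limit) + summarize_recursive(text[mid:], limit)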
@@ -627,13 +600,8 @@ def api_summarize():
if "text" not in data or not isinstance(data["text"], str):
abort(400, '"text" is required')
params = DEFAULT_SUMMARIZE_PARAMS.copy()
if "params" in data and isinstance(data["params"], dict):
params.update(data["params"])
print("Summary input:", data["text"], sep="\n")
summary = summarize_chunks(data["text"], params)
summary = summarize_chunks(data["text"])
print("Summary output:", summary, sep="\n")
gc.collect()
return jsonify({"summary": summary})
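
With the params handling gone, a client only needs to send the text field; any params it still posts are ignored rather than rejected. A hedged example of calling the endpoint with requests (the /api/summarize route and port 5100 are assumptions about the surrounding server setup and are not shown in this diff):

import requests

# Assumed URL; the Flask route decorator for api_summarize() is outside this hunk.
resp = requests.post(
    "http://localhost:5100/api/summarize",
    json={"text": "A very long chat transcript to be condensed..."},
)
print(resp.json()["summary"])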