Chunked summarization

This commit is contained in:
SillyLossy
2023-04-14 13:41:48 +03:00
parent 747bb1252c
commit 28351a4ca8

View File

@@ -217,6 +217,17 @@ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
return caption
def summarize_chunks(text: str, params: dict) -> str:
    """Summarize ``text``, halving it recursively when ``summarize`` fails.

    ``summarize`` is tried on the whole text first. If it raises
    ``IndexError`` (reported below as the sequence being too long for the
    model), the text is cut at its midpoint, the ``max_length`` and
    ``min_length`` budgets are halved, and each half is summarized
    separately.

    Args:
        text: The text to summarize.
        params: Generation parameters passed through to ``summarize``. Must
            contain ``max_length`` and ``min_length``; the caller's dict is
            never modified.

    Returns:
        The summary of ``text``, or the summaries of its pieces joined by a
        single space.

    Raises:
        IndexError: If ``summarize`` still fails on a chunk that is too
            short to be split any further.
    """
    try:
        return summarize(text, params)
    except IndexError:
        midpoint = len(text) // 2
        if midpoint == 0:
            # A 0- or 1-character chunk cannot shrink: text[:0] is empty and
            # text[0:] is the same text again, so recursing would never end.
            raise
        print("Sequence length too large for model, cutting text in half and calling again")
        new_params = params.copy()
        # summarize() coerces these values with int(), so they may arrive as
        # strings or floats; coerce before halving to avoid a TypeError.
        new_params['max_length'] = int(new_params['max_length']) // 2
        new_params['min_length'] = int(new_params['min_length']) // 2
        halves = (
            summarize_chunks(text[:midpoint], new_params),
            summarize_chunks(text[midpoint:], new_params),
        )
        # Separate the partial summaries so their boundary words do not fuse.
        return " ".join(part for part in halves if part)
def summarize(text: str, params: dict) -> str:
# Tokenize input
inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
@@ -229,8 +240,8 @@ def summarize(text: str, params: dict) -> str:
summary_ids = summarization_transformer.generate(
inputs["input_ids"],
num_beams=2,
min_length=min(token_count, int(params['min_length'])),
max_length=max(token_count, int(params['max_length'])),
max_new_tokens=max(token_count, int(params['max_length'])),
min_new_tokens=min(token_count, int(params['min_length'])),
repetition_penalty=float(params['repetition_penalty']),
temperature=float(params['temperature']),
length_penalty=float(params['length_penalty']),
@@ -361,7 +372,7 @@ def api_summarize():
if 'params' in data and isinstance(data['params'], dict):
params.update(data['params'])
summary = summarize(data['text'], params)
summary = summarize_chunks(data['text'], params)
return jsonify({'summary': summary})