diff --git a/server.py b/server.py
index 55eb69e..097652a 100644
--- a/server.py
+++ b/server.py
@@ -207,10 +207,7 @@ if "caption" in modules:
 
 if "summarize" in modules:
     print("Initializing a text summarization model...")
-    summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
-    summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
-        summarization_model, torch_dtype=torch_dtype
-    ).to(device)
+    summarization_pipeline = pipeline('summarization', model=summarization_model, device=device_string, torch_dtype=torch_dtype)
 
 if "sd" in modules and not sd_use_remote:
     from diffusers import StableDiffusionPipeline
@@ -437,44 +434,20 @@ def caption_image(raw_image: Image) -> str:
     return caption
 
 
-def summarize_chunks(text: str, params: dict) -> str:
+def summarize_chunks(text: str) -> str:
     try:
-        return summarize(text, params)
+        return summarize(text)
     except IndexError:
         print(
             "Sequence length too large for model, cutting text in half and calling again"
         )
-        new_params = params.copy()
-        new_params["max_length"] = new_params["max_length"] // 2
-        new_params["min_length"] = new_params["min_length"] // 2
         return summarize_chunks(
-            text[: (len(text) // 2)], new_params
-        ) + summarize_chunks(text[(len(text) // 2) :], new_params)
+            text[: (len(text) // 2)]
+        ) + summarize_chunks(text[(len(text) // 2) :])
 
 
-def summarize(text: str, params: dict) -> str:
-    # Tokenize input
-    inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
-    token_count = len(inputs[0])
-
-    bad_words_ids = [
-        summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
-        for bad_word in params["bad_words"]
-    ]
-    summary_ids = summarization_transformer.generate(
-        inputs["input_ids"],
-        num_beams=2,
-        max_new_tokens=max(token_count, int(params["max_length"])),
-        min_new_tokens=min(token_count, int(params["min_length"])),
-        repetition_penalty=float(params["repetition_penalty"]),
-        temperature=float(params["temperature"]),
-        length_penalty=float(params["length_penalty"]),
-        bad_words_ids=bad_words_ids,
-    )
-    summary = summarization_tokenizer.batch_decode(
-        summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-    )[0]
-    summary = normalize_string(summary)
+def summarize(text: str) -> str:
+    summary = normalize_string(summarization_pipeline(text)[0]['summary_text'])
     return summary
 
 
@@ -627,13 +600,8 @@ def api_summarize():
     if "text" not in data or not isinstance(data["text"], str):
         abort(400, '"text" is required')
 
-    params = DEFAULT_SUMMARIZE_PARAMS.copy()
-
-    if "params" in data and isinstance(data["params"], dict):
-        params.update(data["params"])
-
     print("Summary input:", data["text"], sep="\n")
-    summary = summarize_chunks(data["text"], params)
+    summary = summarize_chunks(data["text"])
     print("Summary output:", summary, sep="\n")
     gc.collect()
    return jsonify({"summary": summary})
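
For context, here is a minimal, self-contained sketch of the summarization path as it looks after this patch. The checkpoint name, the device string, and the normalize_string stand-in are illustrative assumptions, not values from server.py, which wires in its own configuration:

# A minimal sketch of the pipeline-based flow above, assuming an example
# checkpoint and CPU inference; server.py supplies its own model/device.
from transformers import pipeline

summarization_model = "sshleifer/distilbart-cnn-12-6"  # assumed checkpoint
device_string = "cpu"  # e.g. "cuda:0" on a GPU box

summarization_pipeline = pipeline(
    "summarization", model=summarization_model, device=device_string
)


def normalize_string(s: str) -> str:
    # Stand-in for server.py's helper: collapse runs of whitespace.
    return " ".join(s.split())


def summarize(text: str) -> str:
    return normalize_string(summarization_pipeline(text)[0]["summary_text"])


def summarize_chunks(text: str) -> str:
    # Inputs past the model's context window surface as an IndexError from
    # the position embeddings; halve the text and summarize each part.
    try:
        return summarize(text)
    except IndexError:
        mid = len(text) // 2
        return summarize_chunks(text[:mid]) + summarize_chunks(text[mid:])

One behavioral consequence of the change worth noting: the split-and-recurse fallback no longer halves max_length/min_length at each level, and the generation knobs previously taken from the request's "params" dict (repetition_penalty, temperature, length_penalty, bad_words) are dropped in favor of the pipeline's defaults.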