mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-05-01 03:41:24 +00:00
Fix colab running
This commit is contained in:
@@ -154,7 +154,7 @@ def load_extensions():
|
||||
|
||||
|
||||
# AI stuff
|
||||
def classify_text(text: str) -> list[dict]:
|
||||
def classify_text(text: str) -> list:
|
||||
output = bert_classifier(text)[0]
|
||||
return sorted(output, key=lambda x: x['score'], reverse=True)
|
||||
|
||||
@@ -198,7 +198,7 @@ def normalize_string(input: str) -> str:
|
||||
return output
|
||||
|
||||
|
||||
def extract_keywords(text: str) -> list[str]:
|
||||
def extract_keywords(text: str) -> list:
|
||||
punctuation = '(){}[]\n\r<>'
|
||||
trans = str.maketrans(punctuation, ' '*len(punctuation))
|
||||
text = text.translate(trans)
|
||||
@@ -206,7 +206,7 @@ def extract_keywords(text: str) -> list[str]:
|
||||
return list(keyphrase_pipe(text))
|
||||
|
||||
|
||||
def generate_prompt(keywords: list[str], length: int = 100, num: int = 4) -> str:
|
||||
def generate_prompt(keywords: list, length: int = 100, num: int = 4) -> str:
|
||||
prompt = ', '.join(keywords)
|
||||
outs = prompt_generator(prompt, max_length=length, num_return_sequences=num, do_sample=True,
|
||||
repetition_penalty=1.2, temperature=0.7, top_k=4, early_stopping=True)
|
||||
|
||||
Reference in New Issue
Block a user