mirror of
https://github.com/SillyTavern/SillyTavern-Extras.git
synced 2026-03-02 18:10:04 +00:00
27 lines
887 B
Python
27 lines
887 B
Python
from transformers import (
|
|
AutoModelForTokenClassification,
|
|
AutoTokenizer,
|
|
TokenClassificationPipeline,
|
|
)
|
|
from transformers.pipelines import AggregationStrategy
|
|
import numpy as np
|
|
|
|
|
|
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns unique keyphrases.

    Loads a token-classification model and its tokenizer from the same
    checkpoint name/path. It runs the standard ``TokenClassificationPipeline``
    and reduces the aggregated entities to a sorted, de-duplicated array of
    stripped ``word`` strings.
    """

    def __init__(self, model, *args, **kwargs):
        """Build the pipeline from a pretrained checkpoint.

        Args:
            model: Name or path given to both
                ``AutoModelForTokenClassification.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            *args: Forwarded positionally to ``TokenClassificationPipeline``.
            **kwargs: Forwarded as keyword arguments to
                ``TokenClassificationPipeline``.
        """
        # NOTE(review): Python places ``*args`` values before all keyword
        # arguments regardless of call-site order. Any extra positional
        # values therefore bind to the parent's *leading* positional
        # parameters, not to the ones after ``tokenizer``. Prefer passing
        # options as keywords.
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs
        )

    def postprocess(self, model_outputs, **kwargs):
        """Convert raw model outputs into unique keyphrases.

        Args:
            model_outputs: Outputs of the pipeline's forward step, handed
                unchanged to ``TokenClassificationPipeline.postprocess``.
            **kwargs: Post-processing options forwarded by the pipeline
                (for example ``aggregation_strategy`` or ``ignore_labels``).
                An explicitly supplied ``aggregation_strategy`` takes
                precedence over the model-type default below.

        Returns:
            numpy.ndarray: Sorted, de-duplicated keyphrase strings with
            surrounding whitespace stripped.
        """
        # Default grouping: SIMPLE for RoBERTa models, FIRST for all others.
        # Only used when the caller did not request a strategy.
        if kwargs.get("aggregation_strategy") is None:
            kwargs["aggregation_strategy"] = (
                AggregationStrategy.SIMPLE
                if self.model.config.model_type == "roberta"
                else AggregationStrategy.FIRST
            )
        # Pass the outputs positionally. The parent's first parameter is
        # named ``model_outputs`` in older transformers releases and
        # ``all_outputs`` in newer ones, so a keyword would break on one
        # of them.
        results = super().postprocess(model_outputs, **kwargs)
        return np.unique([result.get("word").strip() for result in results])
|