diff --git a/model.py b/model.py index a5cb436..a5d4e2b 100644 --- a/model.py +++ b/model.py @@ -333,9 +333,12 @@ class ModelContainer: return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0] def generate(self, prompt: str, **kwargs): - gen = list(self.generate_gen(prompt, **kwargs)) - reponse = "".join(map(lambda o: o[0], gen)) - return reponse, gen[-1][1], gen[-1][2] + generation = list(self.generate_gen(prompt, **kwargs)) + if generation: + response = "".join(map(lambda chunk: chunk[0], generation)) + return response, generation[-1][1], generation[-1][2] + else: + return "", 0, 0 def generate_gen(self, prompt: str, **kwargs): """