diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index e88fc6b..61726e2 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -3,6 +3,7 @@ import asyncio import pathlib from asyncio import CancelledError +from copy import deepcopy from typing import List, Optional from uuid import uuid4 @@ -242,12 +243,21 @@ async def generate_chat_completion( prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path ): gen_tasks: List[asyncio.Task] = [] + gen_params = data.to_gen_params() try: - for _ in range(0, data.n): + for n in range(0, data.n): + + # Deepcopy gen params above the first index + # to ensure nested structures aren't shared + if n > 0: + task_gen_params = deepcopy(gen_params) + else: + task_gen_params = gen_params + gen_tasks.append( asyncio.create_task( - model.container.generate(prompt, **data.to_gen_params()) + model.container.generate(prompt, **task_gen_params) ) ) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index b86ccd5..5034fbc 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -3,6 +3,7 @@ import asyncio import pathlib from asyncio import CancelledError +from copy import deepcopy from fastapi import HTTPException, Request from typing import List, Optional @@ -111,12 +112,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path) """Non-streaming generate for completions""" gen_tasks: List[asyncio.Task] = [] + gen_params = data.to_gen_params() try: - for _ in range(0, data.n): + for n in range(0, data.n): + + # Deepcopy gen params above the first index + # to ensure nested structures aren't shared + if n > 0: + task_gen_params = deepcopy(gen_params) + else: + task_gen_params = gen_params + gen_tasks.append( asyncio.create_task( - model.container.generate(data.prompt, **data.to_gen_params()) + model.container.generate(data.prompt, **task_gen_params) ) )