diff --git a/OAI/types/common.py b/OAI/types/common.py index c0afb06..919b6e4 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -79,7 +79,7 @@ class CommonCompletionRequest(BaseModel): "repetition_penalty": self.repetition_penalty, "repetition_penalty_range": self.repetition_penalty_range, "repetition_decay": self.repetition_decay, - "mirostat": True if self.mirostat_mode == 2 else False, + "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, "mirostat_eta": self.mirostat_eta } diff --git a/main.py b/main.py index 8da2738..6da88d5 100644 --- a/main.py +++ b/main.py @@ -48,13 +48,14 @@ async def list_models(): models = get_model_list(model_path) - return models.json() + return models # Currently loaded model endpoint @app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) +@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def get_current_model(): model_card = ModelCard(id=model_container.get_model_path().name) - return model_card.json() + return model_card # Load model endpoint @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) @@ -84,17 +85,21 @@ async def load_model(data: ModelLoadRequest): else: loading_bar.next() - yield ModelLoadResponse( + response = ModelLoadResponse( module=module, modules=modules, status="processing" - ).json() + ) - yield ModelLoadResponse( + yield response.json(ensure_ascii=False) + + response = ModelLoadResponse( module=module, modules=modules, status="finished" - ).json() + ) + + yield response.json(ensure_ascii=False) return EventSourceResponse(generator()) @@ -112,7 +117,7 @@ async def encode_tokens(data: TokenEncodeRequest): tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist() response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) - return response.json() + return response # Decode tokens endpoint @app.post("/v1/token/decode", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @@ -120,7 +125,7 @@ async def decode_tokens(data: TokenDecodeRequest): message = model_container.get_tokens(None, data.tokens, **data.get_params()) response = TokenDecodeResponse(text=message) - return response.json() + return response # Completions endpoint @app.post("/v1/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @@ -139,14 +144,14 @@ async def generate_completion(request: Request, data: CompletionRequest): response = create_completion_response(part, model_path.name) - yield response.json() + yield response.json(ensure_ascii=False) return EventSourceResponse(generator()) else: response_text = model_container.generate(data.prompt, **data.to_gen_params()) response = create_completion_response(response_text, model_path.name) - return response.json() + return response # Chat completions endpoint @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @@ -172,14 +177,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest model_path.name ) - yield response.json() + yield response.json(ensure_ascii=False) return EventSourceResponse(generator()) else: response_text = model_container.generate(prompt, **data.to_gen_params()) response = create_chat_completion_response(response_text, model_path.name) - return response.json() + return response if __name__ == "__main__": # Initialize auth keys @@ -196,7 +201,7 @@ if __name__ == "__main__": model_config = config.get("model", {}) if "model_name" in model_config: - model_path = pathlib.Path(model_config.get("model", "models")) + model_path = pathlib.Path(model_config.get("model_dir", "models")) model_path = model_path / model_config["model_name"] model_container = ModelContainer(model_path, False, **model_config)