From ab84b01fdf1123811412ce80c64590cfd8162a16 Mon Sep 17 00:00:00 2001 From: Splice86 Date: Fri, 10 Nov 2023 00:39:08 -0600 Subject: [PATCH] Updated readme --- README.md | 32 ++++++++++++++++++++++++++++++++ main.py | 24 +++++++++++++++--------- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 008bbe1..1fcb2ce 100644 --- a/README.md +++ b/README.md @@ -81,3 +81,35 @@ curl -X POST "http://localhost:8000/generate-text" -H "Content-Type: application ], "temperature": 0.7 }' + + +### Parameter Guide + +*Note:* This section is still being expanded and updated. + +{ + "prompt": "A tabby is a", + "max_tokens": 200, + "temperature": 1, + "top_p": 0.9, + "seed": 10, + "stream": true, + "token_repetition_penalty": 0.5, + "stop": ["###"] +} + +prompt: The initial text or message that sets the context for the generated completion. + +max_tokens: The maximum number of tokens to generate in the response. + +temperature: Controls the randomness of the output; higher values produce more varied text, lower values produce more deterministic text. + +top_p: Controls the diversity of the output by restricting sampling to the smallest set of tokens whose cumulative probability reaches p (nucleus sampling). + +seed: A seed value that makes generation reproducible; requests with the same seed produce the same results (10 in the example above). + +stream: A boolean that, when true, enables Server-Sent Events (SSE) streaming of the response. + +token_repetition_penalty: Controls the penalty applied to repeated tokens in the generated text. + +stop: An array of strings; if any of them appears in the generated text, the model stops generating. 
diff --git a/main.py b/main.py index 1a36284..1e0e66a 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,4 @@ -# main.py - +import os from fastapi import FastAPI, HTTPException from pydantic import BaseModel from llm import ModelManager @@ -7,8 +6,9 @@ from uvicorn import run app = FastAPI() -# Example: Using a different model directory -modelManager = ModelManager("/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2") +# Initialize the modelManager with a default model path +default_model_path = "~/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2" +modelManager = ModelManager(default_model_path) class TextRequest(BaseModel): model: str @@ -21,15 +21,21 @@ class TextResponse(BaseModel): @app.post("/generate-text", response_model=TextResponse) def generate_text(request: TextRequest): + global modelManager try: - #model_path = request.model # You can use this path to load a specific model if needed - messages = request.messages - #temperature = request.temperature + model_path = request.model - # Assuming you need to extract the user's message from the messages list + if model_path and model_path != modelManager.config.model_path: + # Check if the specified model path exists + if not os.path.exists(model_path): + raise HTTPException(status_code=400, detail="Model path does not exist") + + # Reinitialize the modelManager with the new model path + modelManager = ModelManager(model_path) + + messages = request.messages user_message = next(msg["content"] for msg in messages if msg["role"] == "user") - # You can then use user_message as the prompt for generation output, generation_time = modelManager.generate_text(user_message) return {"response": output, "generation_time": generation_time} except RuntimeError as e: