diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..26d3352
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,3 @@
+# Default ignored files
+/shelf/
+/workspace.xml
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..d378e48
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,20 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..6931d08
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,4 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..f9dcddc
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/tabbyAPI.iml b/.idea/tabbyAPI.iml
new file mode 100644
index 0000000..74d515a
--- /dev/null
+++ b/.idea/tabbyAPI.iml
@@ -0,0 +1,10 @@
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/README.md b/README.md
index 1fcb2ce..bd7ca42 100644
--- a/README.md
+++ b/README.md
@@ -74,13 +74,19 @@ The tabbyAPI application provides the following endpoint:
### Example Request (using `curl`)
-curl -X POST "http://localhost:8000/generate-text" -H "Content-Type: application/json" -d '{
- "model": "your_model_name",
- "messages": [
- {"role": "user", "content": "Say this is a test!"}
- ],
- "temperature": 0.7
-}'
+curl http://127.0.0.1:8000/generate-text \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "Your_Model_Path",
+ "prompt": "A tabby is a",
+ "max_tokens": 200,
+ "temperature": 1,
+ "top_p": 0.9,
+ "seed": 10,
+ "stream": true,
+    "token_repetition_penalty": 1.05,
+ "stop": ["###"]
+ }'
### Parameter Guide
diff --git a/llm.py b/llm.py
index b0a2336..19167a3 100644
--- a/llm.py
+++ b/llm.py
@@ -11,6 +11,8 @@ from exllamav2.generator import (
ExLlamaV2Sampler
)
import time
+
+
class ModelManager:
def __init__(self, model_directory: str = None):
if model_directory is None:
@@ -24,12 +26,25 @@ class ModelManager:
self.model.load_autosplit(self.cache)
self.tokenizer = ExLlamaV2Tokenizer(self.config)
self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)
- def generate_text(self, prompt: str, max_new_tokens: int = 150,seed: int = random.randint(0,999999) ):
+
+ def generate_text(self,
+ prompt: str,
+ max_tokens: int = 150,
+ temperature=0.5,
+ seed: int = random.randint(0, 999999),
+ token_repetition_penalty: float = 1.0,
+ stop: list = None):
try:
self.generator.warmup()
time_begin = time.time()
+ settings = ExLlamaV2Sampler.Settings()
+ settings.token_repetition_penalty = token_repetition_penalty
+
+            if stop:
+                settings.stop_sequence = stop  # NOTE(review): confirm ExLlamaV2Sampler.Settings defines stop_sequence and that generate_simple honors it
+
output = self.generator.generate_simple(
- prompt, ExLlamaV2Sampler.Settings(), max_new_tokens, seed=seed
+ prompt, settings, max_tokens, seed=seed
)
time_end = time.time()
time_total = time_end - time_begin
diff --git a/main.py b/main.py
index 1e0e66a..7efab37 100644
--- a/main.py
+++ b/main.py
@@ -7,13 +7,19 @@ from uvicorn import run
app = FastAPI()
# Initialize the modelManager with a default model path
-default_model_path = "~/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2"
+default_model_path = os.path.expanduser("~/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2")  # expand ~ instead of hard-coding one user's home directory
modelManager = ModelManager(default_model_path)
-
+
class TextRequest(BaseModel):
- model: str
- messages: list[dict]
- temperature: float
+ model: str = None # Make the "model" field optional with a default value of None
+ prompt: str
+ max_tokens: int = 200
+ temperature: float = 1
+ top_p: float = 0.9
+ seed: int = 10
+ stream: bool = False
+ token_repetition_penalty: float = 1.0
+ stop: list = None
class TextResponse(BaseModel):
response: str
@@ -23,20 +29,9 @@ class TextResponse(BaseModel):
def generate_text(request: TextRequest):
global modelManager
try:
- model_path = request.model
-
- if model_path and model_path != modelManager.config.model_path:
- # Check if the specified model path exists
- if not os.path.exists(model_path):
- raise HTTPException(status_code=400, detail="Model path does not exist")
-
- # Reinitialize the modelManager with the new model path
- modelManager = ModelManager(model_path)
-
- messages = request.messages
- user_message = next(msg["content"] for msg in messages if msg["role"] == "user")
-
- output, generation_time = modelManager.generate_text(user_message)
+ prompt = request.prompt # Get the prompt from the request
+ user_message = prompt # Assuming that prompt is equivalent to the user's message
+        output, generation_time = modelManager.generate_text(prompt=user_message, max_tokens=request.max_tokens, temperature=request.temperature, seed=request.seed, token_repetition_penalty=request.token_repetition_penalty, stop=request.stop)
return {"response": output, "generation_time": generation_time}
except RuntimeError as e:
raise HTTPException(status_code=500, detail=str(e))