mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Update to README and other minor changes
This commit is contained in:
3
.idea/.gitignore
generated
vendored
Normal file
3
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
20
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
20
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,20 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="7">
|
||||
<item index="0" class="java.lang.String" itemvalue="scipy" />
|
||||
<item index="1" class="java.lang.String" itemvalue="transformers" />
|
||||
<item index="2" class="java.lang.String" itemvalue="sounddevice" />
|
||||
<item index="3" class="java.lang.String" itemvalue="matplotlib" />
|
||||
<item index="4" class="java.lang.String" itemvalue="librosa" />
|
||||
<item index="5" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="6" class="java.lang.String" itemvalue="flask" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
4
.idea/misc.xml
generated
Normal file
4
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (tabbyAPI)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/tabbyAPI.iml" filepath="$PROJECT_DIR$/.idea/tabbyAPI.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
10
.idea/tabbyAPI.iml
generated
Normal file
10
.idea/tabbyAPI.iml
generated
Normal file
@@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
20
README.md
20
README.md
@@ -74,13 +74,19 @@ The tabbyAPI application provides the following endpoint:
|
||||
### Example Request (using `curl`)
|
||||
|
||||
|
||||
curl -X POST "http://localhost:8000/generate-text" -H "Content-Type: application/json" -d '{
|
||||
"model": "your_model_name",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Say this is a test!"}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
curl http://127.0.0.1:8000/generate-text \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "Your_Model_Path",
|
||||
"prompt": "A tabby is a",
|
||||
"max_tokens": 200,
|
||||
"temperature": 1,
|
||||
"top_p": 0.9,
|
||||
"seed": 10,
|
||||
"stream": true,
|
||||
"token_repetition_penalty": 0.5,
|
||||
"stop": ["###"]
|
||||
}'
|
||||
|
||||
|
||||
### Parameter Guide
|
||||
|
||||
19
llm.py
19
llm.py
@@ -11,6 +11,8 @@ from exllamav2.generator import (
|
||||
ExLlamaV2Sampler
|
||||
)
|
||||
import time
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, model_directory: str = None):
|
||||
if model_directory is None:
|
||||
@@ -24,12 +26,25 @@ class ModelManager:
|
||||
self.model.load_autosplit(self.cache)
|
||||
self.tokenizer = ExLlamaV2Tokenizer(self.config)
|
||||
self.generator = ExLlamaV2BaseGenerator(self.model, self.cache, self.tokenizer)
|
||||
def generate_text(self, prompt: str, max_new_tokens: int = 150,seed: int = random.randint(0,999999) ):
|
||||
|
||||
def generate_text(self,
|
||||
prompt: str,
|
||||
max_tokens: int = 150,
|
||||
temperature=0.5,
|
||||
seed: int = random.randint(0, 999999),
|
||||
token_repetition_penalty: float = 1.0,
|
||||
stop: list = None):
|
||||
try:
|
||||
self.generator.warmup()
|
||||
time_begin = time.time()
|
||||
settings = ExLlamaV2Sampler.Settings()
|
||||
settings.token_repetition_penalty = token_repetition_penalty
|
||||
|
||||
if stop:
|
||||
settings.stop_sequence = stop
|
||||
|
||||
output = self.generator.generate_simple(
|
||||
prompt, ExLlamaV2Sampler.Settings(), max_new_tokens, seed=seed
|
||||
prompt, settings, max_tokens, seed=seed
|
||||
)
|
||||
time_end = time.time()
|
||||
time_total = time_end - time_begin
|
||||
|
||||
33
main.py
33
main.py
@@ -7,13 +7,19 @@ from uvicorn import run
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize the modelManager with a default model path
|
||||
default_model_path = "~/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2"
|
||||
default_model_path = "/home/david/Models/SynthIA-7B-v2.0-5.0bpw-h6-exl2"
|
||||
modelManager = ModelManager(default_model_path)
|
||||
|
||||
print(output)
|
||||
class TextRequest(BaseModel):
|
||||
model: str
|
||||
messages: list[dict]
|
||||
temperature: float
|
||||
model: str = None # Make the "model" field optional with a default value of None
|
||||
prompt: str
|
||||
max_tokens: int = 200
|
||||
temperature: float = 1
|
||||
top_p: float = 0.9
|
||||
seed: int = 10
|
||||
stream: bool = False
|
||||
token_repetition_penalty: float = 1.0
|
||||
stop: list = None
|
||||
|
||||
class TextResponse(BaseModel):
|
||||
response: str
|
||||
@@ -23,20 +29,9 @@ class TextResponse(BaseModel):
|
||||
def generate_text(request: TextRequest):
|
||||
global modelManager
|
||||
try:
|
||||
model_path = request.model
|
||||
|
||||
if model_path and model_path != modelManager.config.model_path:
|
||||
# Check if the specified model path exists
|
||||
if not os.path.exists(model_path):
|
||||
raise HTTPException(status_code=400, detail="Model path does not exist")
|
||||
|
||||
# Reinitialize the modelManager with the new model path
|
||||
modelManager = ModelManager(model_path)
|
||||
|
||||
messages = request.messages
|
||||
user_message = next(msg["content"] for msg in messages if msg["role"] == "user")
|
||||
|
||||
output, generation_time = modelManager.generate_text(user_message)
|
||||
prompt = request.prompt # Get the prompt from the request
|
||||
user_message = prompt # Assuming that prompt is equivalent to the user's message
|
||||
output, generation_time = modelManager.generate_text(prompt=user_message)
|
||||
return {"response": output, "generation_time": generation_time}
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user