Tree: Switch to Pydantic 2

Pydantic 2 offers more modern methods and better stability than Pydantic 1

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2023-12-17 01:06:10 -05:00
committed by Brian Dashore
parent f631dd6ff7
commit 51ca1ff396
8 changed files with 18 additions and 15 deletions

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from gen_logging import LogConfig
@@ -45,6 +45,9 @@ class ModelLoadRequest(BaseModel):
draft: Optional[DraftModelLoadRequest] = None
class ModelLoadResponse(BaseModel):
# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces = [])
model_type: str = "model"
module: int
modules: int

View File

@@ -30,7 +30,7 @@ def load_auth_keys():
try:
with open("api_tokens.yml", "r", encoding = 'utf8') as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
auth_keys = AuthKeys.parse_obj(auth_keys_dict)
auth_keys = AuthKeys.model_validate(auth_keys_dict)
except Exception as _:
new_auth_keys = AuthKeys(
api_key = secrets.token_hex(16),
@@ -39,7 +39,7 @@ def load_auth_keys():
auth_keys = new_auth_keys
with open("api_tokens.yml", "w", encoding = "utf8") as auth_file:
yaml.safe_dump(auth_keys.dict(), auth_file, default_flow_style=False)
yaml.safe_dump(auth_keys.model_dump(), auth_file, default_flow_style=False)
print(
f"Your API key is: {auth_keys.api_key}\n"

View File

@@ -18,7 +18,7 @@ def update_from_dict(options_dict: Dict[str, bool]):
if value is None:
value = False
config = LogConfig.parse_obj(options_dict)
config = LogConfig.model_validate(options_dict)
def broadcast_status():
enabled = []

14
main.py
View File

@@ -117,7 +117,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / data.name
load_data = data.dict()
load_data = data.model_dump()
# TODO: Add API exception if draft directory isn't found
draft_config = unwrap(model_config.get("draft"), {})
@@ -156,7 +156,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
status="finished"
)
yield get_sse_packet(response.json(ensure_ascii = False))
yield get_sse_packet(response.model_dump_json())
# Switch to model progress if the draft model is loaded
if model_container.draft_config:
@@ -171,7 +171,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
status="processing"
)
yield get_sse_packet(response.json(ensure_ascii=False))
yield get_sse_packet(response.model_dump_json())
except CancelledError:
print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
except Exception as e:
@@ -230,7 +230,7 @@ async def load_lora(data: LoraLoadRequest):
if len(model_container.active_loras) > 0:
model_container.unload(True)
result = model_container.load_loras(lora_dir, **data.dict())
result = model_container.load_loras(lora_dir, **data.model_dump())
return LoraLoadResponse(
success = unwrap(result.get("success"), []),
failure = unwrap(result.get("failure"), [])
@@ -281,7 +281,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
completion_tokens,
model_path.name)
yield get_sse_packet(response.json(ensure_ascii=False))
yield get_sse_packet(response.model_dump_json())
except CancelledError:
print("Error: Completion request cancelled by user.")
except Exception as e:
@@ -334,7 +334,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
model_path.name
)
yield get_sse_packet(response.json(ensure_ascii=False))
yield get_sse_packet(response.model_dump_json())
# Yield a finish response on successful generation
finish_response = create_chat_completion_stream_chunk(
@@ -342,7 +342,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
finish_reason = "stop"
)
yield get_sse_packet(finish_response.json(ensure_ascii=False))
yield get_sse_packet(finish_response.model_dump_json())
except CancelledError:
print("Error: Chat completion cancelled by user.")
except Exception as e:

View File

@@ -8,7 +8,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1
# Pip dependencies
fastapi
pydantic < 2,>= 1
pydantic
PyYAML
progress
uvicorn

View File

@@ -14,7 +14,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1
# Pip dependencies
fastapi
pydantic < 2,>= 1
pydantic
PyYAML
progress
uvicorn

View File

@@ -14,7 +14,7 @@ https://github.com/turboderp/exllamav2/releases/download/v0.0.11/exllamav2-0.0.1
# Pip dependencies
fastapi
pydantic < 2,>= 1
pydantic
PyYAML
progress
uvicorn

View File

@@ -26,7 +26,7 @@ def get_generator_error(message: str):
# Log and send the exception
print(f"\n{generator_error.error.trace}")
return get_sse_packet(generator_error.json(ensure_ascii = False))
return get_sse_packet(generator_error.model_dump_json())
def get_sse_packet(json_data: str):
return f"data: {json_data}\n\n"