Tree: Switch to asynchronous file handling

Using aiofiles, there's no longer a possiblity of blocking file operations
that can hang up the event loop. In addition, partially migrate
classes to use asynchronous init instead of the normal python magic method.

The only exception is config, since that's handled in the synchonous
init before the event loop starts.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-09-10 16:45:14 -04:00
parent 54bfb770af
commit 2c3bc71afa
9 changed files with 63 additions and 36 deletions

View File

@@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import aiofiles
import json
import pathlib
from importlib.metadata import version as package_version
@@ -110,7 +111,7 @@ class PromptTemplate:
self.template = self.compile(raw_template)
@classmethod
def from_file(self, template_path: pathlib.Path):
async def from_file(self, template_path: pathlib.Path):
"""Get a template from a jinja file."""
# Add the jinja extension if it isn't provided
@@ -121,10 +122,13 @@ class PromptTemplate:
template_path = template_path.with_suffix(".jinja")
if template_path.exists():
with open(template_path, "r", encoding="utf8") as raw_template_stream:
async with aiofiles.open(
template_path, "r", encoding="utf8"
) as raw_template_stream:
contents = await raw_template_stream.read()
return PromptTemplate(
name=template_name,
raw_template=raw_template_stream.read(),
raw_template=contents,
)
else:
# Let the user know if the template file isn't found
@@ -133,15 +137,16 @@ class PromptTemplate:
)
@classmethod
def from_model_json(
async def from_model_json(
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
contents = await config_file.read()
model_config = json.loads(contents)
chat_template = model_config.get(key)
if not chat_template: