feat: workflows for formatting/linting (#35)

* add github workflows for pylint and yapf

* yapf

* docstrings for auth

* fix auth.py

* fix generators.py

* fix gen_logging.py

* fix main.py

* fix model.py

* fix templating.py

* fix utils.py

* update formatting.sh to include subdirs for pylint

* fix model_test.py

* fix wheel_test.py

* rename utils to utils_oai

* fix OAI/utils_oai.py

* fix completion.py

* fix token.py

* fix lora.py

* fix common.py

* add pylintrc and fix model.py

* finish up pylint

* fix attribute error

* main.py formatting

* add formatting batch script

* Main: Remove unnecessary global

Linter suggestion.

Signed-off-by: kingbri <bdashore3@proton.me>

* switch to ruff

* Formatting + Linting: Add ruff.toml

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting + Linting: Switch scripts to use ruff

Also remove the file and recent file change functions from both
scripts.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format and lint

Signed-off-by: kingbri <bdashore3@proton.me>

* Scripts + Workflows: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Remove pylint flags

We use ruff now

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting: Line length is 88

Use the same value as Black.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Update to new line length rules.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri <bdashore3@proton.me>
This commit is contained in:
AlpinDale
2023-12-22 16:20:35 +00:00
committed by GitHub
parent a14abfe21c
commit fa47f51f85
22 changed files with 1210 additions and 511 deletions

View File

@@ -1,46 +1,56 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict
# Small replication of AutoTokenizer's chat template system for efficiency
class PromptTemplate(BaseModel):
"""A template for chat completion prompts."""
name: str
template: str
def get_prompt_from_template(messages,
prompt_template: PromptTemplate,
add_generation_prompt: bool,
special_tokens: Optional[Dict[str, str]] = None):
def get_prompt_from_template(
messages,
prompt_template: PromptTemplate,
add_generation_prompt: bool,
special_tokens: Optional[Dict[str, str]] = None,
):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
f"Current version: {version('jinja2')}\n"
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"pip install --upgrade jinja2"
)
compiled_template = _compile_template(prompt_template.template)
return compiled_template.render(
messages = messages,
add_generation_prompt = add_generation_prompt,
messages=messages,
add_generation_prompt=add_generation_prompt,
**special_tokens,
)
# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_template = jinja_env.from_string(template)
return jinja_template
# Find a matching template name from a model path
def find_template_from_model(model_path: pathlib.Path):
"""Find a matching template name from a model path."""
model_name = model_path.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
@@ -50,14 +60,16 @@ def find_template_from_model(model_path: pathlib.Path):
if template_name in model_name.lower():
return template_name
# Get a template from a jinja file
return None
def get_template_from_file(prompt_template_name: str):
"""Get a template from a jinja file."""
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
if template_path.exists():
with open(template_path, "r", encoding = "utf8") as raw_template:
with open(template_path, "r", encoding="utf8") as raw_template:
return PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
name=prompt_template_name, template=raw_template.read()
)
return None
@@ -66,15 +78,12 @@ def get_template_from_file(prompt_template_name: str):
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
"""Get a template from a JSON file. Requires a key and template name"""
if json_path.exists():
with open(json_path, "r", encoding = "utf8") as config_file:
with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
chat_template = model_config.get(key)
if chat_template:
return PromptTemplate(
name = name,
template = chat_template
)
return PromptTemplate(name=name, template=chat_template)
return None