diff --git a/common/downloader.py b/common/downloader.py
new file mode 100644
index 0000000..9d4b93f
--- /dev/null
+++ b/common/downloader.py
@@ -0,0 +1,150 @@
+import aiofiles
+import aiohttp
+import asyncio
+import math
+import pathlib
+import shutil
+from huggingface_hub import HfApi, hf_hub_url
+from loguru import logger
+from rich.progress import Progress
+from typing import Optional
+
+from common.config import lora_config, model_config
+from common.logger import get_progress_bar
+from common.utils import unwrap
+
+
+async def _download_file(
+    session: aiohttp.ClientSession,
+    repo_item: dict,
+    token: Optional[str],
+    download_path: pathlib.Path,
+    chunk_limit: Optional[float],
+    progress: Progress,
+):
+    """Downloads a single repo file from HuggingFace.
+
+    Args:
+        session: Shared aiohttp session used for all file downloads.
+        repo_item: Dict with "filename" and "url" keys (see _get_repo_info).
+        token: Optional HF access token for gated/private repos.
+        download_path: Local folder the file is written into.
+        chunk_limit: Streaming chunk size in bytes (None -> 2 MB).
+        progress: Shared rich progress bar to report into.
+    """
+
+    filename = repo_item.get("filename")
+    url = repo_item.get("url")
+
+    # Chunk size in bytes. Default is 2MB.
+    # FIX: the original multiplied by 100000, turning the 2 MB default
+    # into a ~200 GB chunk request.
+    chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000))
+
+    filepath = download_path / filename
+    filepath.parent.mkdir(parents=True, exist_ok=True)
+
+    req_headers = {"Authorization": f"Bearer {token}"} if token else {}
+
+    async with session.get(url, headers=req_headers) as response:
+        # FIX: raise aiohttp.ClientResponseError on non-2xx statuses
+        # instead of the previous bare assert (stripped under python -O)
+        response.raise_for_status()
+
+        file_size = int(response.headers["Content-Length"])
+
+        # FIX: show the actual filename instead of "(unknown)"
+        download_task = progress.add_task(
+            f"[cyan]Downloading {filename}", total=file_size
+        )
+
+        async with aiofiles.open(str(filepath), "wb") as f:
+            async for chunk in response.content.iter_chunked(chunk_limit_bytes):
+                await f.write(chunk)
+                progress.update(download_task, advance=len(chunk))
+
+
+# Huggingface's client is synchronous; run this in a thread (see hf_repo_download)
+def _get_repo_info(repo_id, revision, token):
+    """Fetches information about a HuggingFace repository.
+
+    Returns a list of {"filename", "url"} dicts, one entry per repo file.
+    """
+
+    api_client = HfApi()
+    repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
+    return list(
+        map(
+            lambda filename: {
+                "filename": filename,
+                "url": hf_hub_url(repo_id, filename, revision=revision),
+            },
+            repo_tree,
+        )
+    )
+
+
+def _get_download_folder(repo_id: str, repo_type: str, folder_name: Optional[str]):
+    """Gets the local download folder for the repo.
+
+    Loras go under the configured lora_dir (default "loras"), everything
+    else under model_dir (default "models"). The leaf folder defaults to
+    the repo name (the part after the "/").
+    """
+
+    if repo_type == "lora":
+        download_path = pathlib.Path(unwrap(lora_config().get("lora_dir"), "loras"))
+    else:
+        download_path = pathlib.Path(unwrap(model_config().get("model_dir"), "models"))
+
+    download_path = download_path / unwrap(folder_name, repo_id.split("/")[-1])
+    return download_path
+
+
+async def hf_repo_download(
+    repo_id: str,
+    folder_name: Optional[str],
+    revision: Optional[str],
+    token: Optional[str],
+    chunk_limit: Optional[float],
+    repo_type: Optional[str] = "model",
+):
+    """Gets a repo's information from HuggingFace and downloads it locally.
+
+    Pass repo_type=None to auto-detect lora repos from the file list.
+    Raises FileExistsError if the target folder already exists.
+    """
+
+    # The HF client is synchronous, so fetch the file list in a thread
+    file_list = await asyncio.to_thread(_get_repo_info, repo_id, revision, token)
+
+    # Auto-detect repo type if it isn't provided
+    if not repo_type:
+        # FIX: the original called filter() without an iterable (a TypeError
+        # at runtime) and then truth-tested the filter object, which is
+        # always True regardless of contents.
+        is_lora = any(
+            repo_item.get("filename", "").endswith(
+                ("adapter_config.json", "adapter_model.bin")
+            )
+            for repo_item in file_list
+        )
+
+        if is_lora:
+            repo_type = "lora"
+
+    download_path = _get_download_folder(repo_id, repo_type, folder_name)
+    download_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if download_path.exists():
+        raise FileExistsError(
+            f"The path {download_path} already exists. Remove the folder and try again."
+        )
+
+    logger.info(f"Saving {repo_id} to {str(download_path)}")
+
+    # Create and start the bar before the try so the handlers below can
+    # always reference it (the original defined it inside the try, so an
+    # early cancellation would hit a NameError in the except block)
+    progress = get_progress_bar()
+    progress.start()
+
+    try:
+        async with aiohttp.ClientSession() as session:
+            logger.info(f"Starting download for {repo_id}")
+
+            # One download task per repo file, all sharing one session and bar
+            tasks = [
+                _download_file(
+                    session,
+                    repo_item,
+                    token=token,
+                    download_path=download_path.resolve(),
+                    chunk_limit=chunk_limit,
+                    progress=progress,
+                )
+                for repo_item in file_list
+            ]
+
+            await asyncio.gather(*tasks)
+            logger.info(f"Finished download for {repo_id}")
+
+            return download_path
+    except asyncio.CancelledError:
+        # Cleanup the partial download on cancel
+        if download_path.is_dir():
+            shutil.rmtree(download_path)
+        else:
+            # FIX: missing_ok avoids FileNotFoundError when nothing was
+            # written before the cancel arrived
+            download_path.unlink(missing_ok=True)
+
+        # FIX: re-raise so the awaiting caller actually sees the cancellation
+        raise
+    finally:
+        # FIX: always stop the progress bar; the original leaked a running
+        # bar on any error other than cancellation
+        progress.stop()
diff --git a/common/logger.py b/common/logger.py
index 82e68dd..f21ab09 100644
--- a/common/logger.py
+++ b/common/logger.py
@@ -23,6 +23,11 @@ RICH_CONSOLE = Console()
 
 LOG_LEVEL = os.getenv("TABBY_LOG_LEVEL", "INFO")
 
 
+def get_progress_bar():
+    """Gets a generic progress bar tied to the shared rich console."""
+    return Progress(console=RICH_CONSOLE)
+
+
 def get_loading_progress_bar():
     """Gets a pre-made progress bar for loading tasks."""
 
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 5b67714..5958c8b 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -13,12 +13,14 @@ from common.concurrency import (
     call_with_semaphore,
     generate_with_semaphore,
 )
+from common.downloader import hf_repo_download
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import coalesce, unwrap
 from endpoints.OAI.types.auth import AuthPermissionResponse
 from endpoints.OAI.types.completion import CompletionRequest
 from endpoints.OAI.types.chat_completion import ChatCompletionRequest
+from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
 from endpoints.OAI.types.lora import (
     LoraCard,
     LoraList,
@@ -288,6 +290,22 @@ async def
unload_sampler_override():
     sampling.overrides_from_dict({})
 
 
+@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
+async def download_model(request: Request, data: DownloadRequest):
+    """Downloads a model from HuggingFace."""
+
+    download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
+
+    # For now, the downloader and request data are 1:1
+    # The task is cancelled (and its partial files cleaned up) if the
+    # client disconnects before the download finishes.
+    download_path = await run_with_request_disconnect(
+        request,
+        download_task,
+        "Download request cancelled by user. Files have been cleaned up.",
+    )
+
+    return DownloadResponse(download_path=str(download_path))
+
+
 # Lora list endpoint
 @router.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
diff --git a/endpoints/OAI/types/download.py b/endpoints/OAI/types/download.py
new file mode 100644
index 0000000..6ba33d9
--- /dev/null
+++ b/endpoints/OAI/types/download.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+from typing import Optional
+
+
+class DownloadRequest(BaseModel):
+    """Parameters for a HuggingFace repo download."""
+
+    repo_id: str
+    repo_type: Optional[str] = "model"
+    folder_name: Optional[str] = None
+    revision: Optional[str] = None
+    token: Optional[str] = None
+    # FIX: float (bytes) to match hf_repo_download's chunk_limit parameter;
+    # ints from existing clients still validate
+    chunk_limit: Optional[float] = None
+
+
+class DownloadResponse(BaseModel):
+    """Response for a download request."""
+
+    download_path: str
diff --git a/pyproject.toml b/pyproject.toml
index 940ab11..144971f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,11 @@ dependencies = [
     "packaging",
     "tokenizers",
     "lm-format-enforcer >= 0.9.6",
+    "aiofiles",
+
+    # TODO: Maybe move these to a downloader feature?
+    "aiohttp",
+    "huggingface_hub",
 ]
 
 [project.urls]