mirror of
https://github.com/theroyallab/tabbyAPI.git
synced 2026-03-14 15:57:27 +00:00
Downloader: Switch to use API sizes
Rather than relying on the Content-Length header, which can be unreliable, query the HuggingFace API for each file's size and use that as the progress total.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,9 @@ import asyncio
|
||||
import math
|
||||
import pathlib
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from huggingface_hub import HfApi, hf_hub_url
|
||||
from huggingface_hub.hf_api import RepoFile
|
||||
from fnmatch import fnmatch
|
||||
from loguru import logger
|
||||
from rich.progress import Progress
|
||||
@@ -15,9 +17,16 @@ from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
@dataclass
|
||||
class RepoItem:
|
||||
path: str
|
||||
size: int
|
||||
url: str
|
||||
|
||||
|
||||
async def _download_file(
|
||||
session: aiohttp.ClientSession,
|
||||
repo_item: dict,
|
||||
repo_item: RepoItem,
|
||||
token: Optional[str],
|
||||
download_path: pathlib.Path,
|
||||
chunk_limit: int,
|
||||
@@ -25,8 +34,8 @@ async def _download_file(
|
||||
):
|
||||
"""Downloads a repo from HuggingFace."""
|
||||
|
||||
filename = repo_item.get("filename")
|
||||
url = repo_item.get("url")
|
||||
filename = repo_item.path
|
||||
url = repo_item.url
|
||||
|
||||
# Default is 2MB
|
||||
chunk_limit_bytes = math.ceil(unwrap(chunk_limit, 2000000) * 100000)
|
||||
@@ -46,33 +55,20 @@ async def _download_file(
|
||||
message=f"HTTP {response.status}: {error_text}",
|
||||
)
|
||||
|
||||
# Sometimes, Content-Length can be undefined
|
||||
content_length = response.headers.get("Content-Length")
|
||||
file_size = int(content_length) if content_length else None
|
||||
|
||||
# Create progress task with appropriate total (None for indeterminate)
|
||||
download_task = progress.add_task(
|
||||
f"[cyan]Downloading {filename}", total=file_size
|
||||
f"[cyan]Downloading {filename}", total=repo_item.size
|
||||
)
|
||||
|
||||
# Chunk limit is 2 MB
|
||||
downloaded_size = 0
|
||||
async with aiofiles.open(str(filepath), "wb") as f:
|
||||
async for chunk in response.content.iter_chunked(chunk_limit_bytes):
|
||||
await f.write(chunk)
|
||||
|
||||
# Store and update progress bar
|
||||
downloaded_size += len(chunk)
|
||||
progress.update(download_task, completed=downloaded_size)
|
||||
|
||||
# For indeterminate files, set final total and mark as complete
|
||||
if file_size is None:
|
||||
progress.update(
|
||||
download_task, total=downloaded_size, completed=downloaded_size
|
||||
)
|
||||
progress.update(download_task, advance=len(chunk))
|
||||
|
||||
|
||||
# Huggingface does not know how async works
|
||||
def _get_repo_info(repo_id, revision, token):
|
||||
"""Fetches information about a HuggingFace repository."""
|
||||
|
||||
@@ -81,13 +77,18 @@ def _get_repo_info(repo_id, revision, token):
|
||||
token = token or None
|
||||
|
||||
api_client = HfApi()
|
||||
repo_tree = api_client.list_repo_files(repo_id, revision=revision, token=token)
|
||||
repo_tree = api_client.list_repo_tree(
|
||||
repo_id, revision=revision, token=token, recursive=True
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"filename": filename,
|
||||
"url": hf_hub_url(repo_id, filename, revision=revision),
|
||||
}
|
||||
for filename in repo_tree
|
||||
RepoItem(
|
||||
path=item.path,
|
||||
size=item.size,
|
||||
url=hf_hub_url(repo_id, item.path, revision=revision),
|
||||
)
|
||||
for item in repo_tree
|
||||
if isinstance(item, RepoFile)
|
||||
]
|
||||
|
||||
|
||||
@@ -130,12 +131,13 @@ async def hf_repo_download(
|
||||
# Auto-detect repo type if it isn't provided
|
||||
if not repo_type:
|
||||
lora_filter = filter(
|
||||
lambda repo_item: repo_item.get("filename", "").endswith(
|
||||
lambda repo_item: repo_item.path.endswith(
|
||||
("adapter_config.json", "adapter_model.bin")
|
||||
)
|
||||
),
|
||||
file_list,
|
||||
)
|
||||
|
||||
if lora_filter:
|
||||
if any(lora_filter):
|
||||
repo_type = "lora"
|
||||
|
||||
if include or exclude:
|
||||
@@ -145,9 +147,7 @@ async def hf_repo_download(
|
||||
file_list = [
|
||||
file
|
||||
for file in file_list
|
||||
if _check_exclusions(
|
||||
file.get("filename"), include_patterns, exclude_patterns
|
||||
)
|
||||
if _check_exclusions(file.path, include_patterns, exclude_patterns)
|
||||
]
|
||||
|
||||
if not file_list:
|
||||
|
||||
Reference in New Issue
Block a user