Downloader: Add include and exclude parameters

These both take an array of glob strings to state what files or
directories to include or exclude when parsing the download list.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri
2024-04-30 00:58:54 -04:00
parent c47869c606
commit 21a01741c9
3 changed files with 48 additions and 13 deletions

View File

@@ -294,16 +294,21 @@ async def unload_sampler_override():
async def download_model(request: Request, data: DownloadRequest):
"""Downloads a model from HuggingFace."""
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
try:
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
return DownloadResponse(download_path=str(download_path))
return DownloadResponse(download_path=str(download_path))
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
# Lora list endpoint

View File

@@ -1,15 +1,17 @@
from pydantic import BaseModel
from typing import Optional
from pydantic import BaseModel, Field
from typing import List, Optional
class DownloadRequest(BaseModel):
"""Parameters for a HuggingFace repo download."""
repo_id: str
repo_type: Optional[str] = "model"
repo_type: str = "model"
folder_name: Optional[str] = None
revision: Optional[str] = None
token: Optional[str] = None
include: List[str] = Field(default_factory=list)
exclude: List[str] = Field(default_factory=list)
chunk_limit: Optional[int] = None