mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-25 15:07:31 +00:00
Compare commits
1 Commits
master
...
feature/do
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f7b77f341 |
507
app/download_manager.py
Normal file
507
app/download_manager.py
Normal file
@@ -0,0 +1,507 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
|
||||
import folder_paths
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from server import PromptServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALLOWED_HTTPS_HOSTS = frozenset({
|
||||
"huggingface.co",
|
||||
"cdn-lfs.huggingface.co",
|
||||
"cdn-lfs-us-1.huggingface.co",
|
||||
"cdn-lfs-eu-1.huggingface.co",
|
||||
"civitai.com",
|
||||
"api.civitai.com",
|
||||
})
|
||||
|
||||
ALLOWED_EXTENSIONS = frozenset({".safetensors", ".sft"})
|
||||
|
||||
MAX_CONCURRENT_DOWNLOADS = 3
|
||||
MAX_TERMINAL_TASKS = 50
|
||||
MAX_REDIRECTS = 10
|
||||
|
||||
DOWNLOAD_TEMP_SUFFIX = ".download_tmp"
|
||||
DOWNLOAD_META_SUFFIX = ".download_meta"
|
||||
|
||||
|
||||
class DownloadStatus(str, Enum):
|
||||
PENDING = "pending"
|
||||
DOWNLOADING = "downloading"
|
||||
PAUSED = "paused"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
ACTIVE_STATUSES = frozenset({
|
||||
DownloadStatus.PENDING,
|
||||
DownloadStatus.DOWNLOADING,
|
||||
DownloadStatus.PAUSED,
|
||||
})
|
||||
|
||||
TERMINAL_STATUSES = frozenset({
|
||||
DownloadStatus.COMPLETED,
|
||||
DownloadStatus.ERROR,
|
||||
DownloadStatus.CANCELLED,
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadTask:
|
||||
id: str
|
||||
url: str
|
||||
filename: str
|
||||
directory: str
|
||||
save_path: str
|
||||
temp_path: str
|
||||
meta_path: str
|
||||
status: DownloadStatus = DownloadStatus.PENDING
|
||||
progress: float = 0.0
|
||||
received_bytes: int = 0
|
||||
total_bytes: int = 0
|
||||
speed_bytes_per_sec: float = 0.0
|
||||
eta_seconds: float = 0.0
|
||||
error: Optional[str] = None
|
||||
created_at: float = field(default_factory=time.time)
|
||||
client_id: Optional[str] = None
|
||||
_worker: Optional[asyncio.Task] = field(default=None, repr=False)
|
||||
_stop_reason: Optional[str] = field(default=None, repr=False)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"url": self.url,
|
||||
"filename": self.filename,
|
||||
"directory": self.directory,
|
||||
"status": self.status.value,
|
||||
"progress": self.progress,
|
||||
"received_bytes": self.received_bytes,
|
||||
"total_bytes": self.total_bytes,
|
||||
"speed_bytes_per_sec": self.speed_bytes_per_sec,
|
||||
"eta_seconds": self.eta_seconds,
|
||||
"error": self.error,
|
||||
"created_at": self.created_at,
|
||||
}
|
||||
|
||||
|
||||
class DownloadManager:
|
||||
def __init__(self, server: PromptServer):
|
||||
self.server = server
|
||||
self.tasks: dict[str, DownloadTask] = {}
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._semaphore = asyncio.Semaphore(MAX_CONCURRENT_DOWNLOADS)
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
if self._session is None or self._session.closed:
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=30, sock_read=60)
|
||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
workers = [t._worker for t in self.tasks.values() if t._worker and not t._worker.done()]
|
||||
for w in workers:
|
||||
w.cancel()
|
||||
if workers:
|
||||
await asyncio.gather(*workers, return_exceptions=True)
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
# -- Validation --
|
||||
|
||||
@staticmethod
|
||||
def _validate_url(url: str) -> Optional[str]:
|
||||
try:
|
||||
parts = urlsplit(url)
|
||||
except Exception:
|
||||
return "Invalid URL"
|
||||
|
||||
if parts.username or parts.password:
|
||||
return "Credentials in URL are not allowed"
|
||||
|
||||
host = (parts.hostname or "").lower()
|
||||
scheme = parts.scheme.lower()
|
||||
|
||||
if scheme != "https":
|
||||
return "Only HTTPS URLs are allowed"
|
||||
|
||||
if host not in ALLOWED_HTTPS_HOSTS:
|
||||
return f"Host '{host}' is not in the allowed list"
|
||||
|
||||
if parts.port not in (None, 443):
|
||||
return "Custom ports are not allowed for remote downloads"
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_filename(filename: str) -> Optional[str]:
|
||||
if not filename:
|
||||
return "Filename must not be empty"
|
||||
ext = os.path.splitext(filename)[1].lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
return f"File extension '{ext}' not allowed. Allowed: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
|
||||
if os.path.sep in filename or (os.path.altsep and os.path.altsep in filename):
|
||||
return "Filename must not contain path separators"
|
||||
if ".." in filename:
|
||||
return "Filename must not contain '..'"
|
||||
for ch in filename:
|
||||
if ord(ch) < 32:
|
||||
return "Filename must not contain control characters"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _validate_directory(directory: str) -> Optional[str]:
|
||||
if directory not in folder_paths.folder_names_and_paths:
|
||||
valid = ', '.join(sorted(folder_paths.folder_names_and_paths.keys()))
|
||||
return f"Unknown model directory '{directory}'. Valid directories: {valid}"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _resolve_save_path(directory: str, filename: str) -> tuple[str, str, str]:
|
||||
"""Returns (save_path, temp_path, meta_path) for a download."""
|
||||
paths = folder_paths.folder_names_and_paths[directory][0]
|
||||
base_dir = paths[0]
|
||||
os.makedirs(base_dir, exist_ok=True)
|
||||
|
||||
save_path = os.path.join(base_dir, filename)
|
||||
temp_path = save_path + DOWNLOAD_TEMP_SUFFIX
|
||||
meta_path = save_path + DOWNLOAD_META_SUFFIX
|
||||
|
||||
real_save = os.path.realpath(save_path)
|
||||
real_base = os.path.realpath(base_dir)
|
||||
if os.path.commonpath([real_save, real_base]) != real_base:
|
||||
raise ValueError("Resolved path escapes the model directory")
|
||||
|
||||
return save_path, temp_path, meta_path
|
||||
|
||||
# -- Sidecar metadata for resume validation --
|
||||
|
||||
@staticmethod
|
||||
def _write_meta(meta_path: str, url: str, task_id: str):
|
||||
try:
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump({"url": url, "task_id": task_id}, f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _read_meta(meta_path: str) -> Optional[dict]:
|
||||
try:
|
||||
with open(meta_path, "r") as f:
|
||||
return json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_files(*paths: str):
|
||||
for p in paths:
|
||||
try:
|
||||
if os.path.exists(p):
|
||||
os.remove(p)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# -- Task management --
|
||||
|
||||
def _prune_terminal_tasks(self):
|
||||
terminal = [
|
||||
(tid, t) for tid, t in self.tasks.items()
|
||||
if t.status in TERMINAL_STATUSES
|
||||
]
|
||||
if len(terminal) > MAX_TERMINAL_TASKS:
|
||||
terminal.sort(key=lambda x: x[1].created_at)
|
||||
to_remove = len(terminal) - MAX_TERMINAL_TASKS
|
||||
for tid, _ in terminal[:to_remove]:
|
||||
del self.tasks[tid]
|
||||
|
||||
async def start_download(
|
||||
self, url: str, directory: str, filename: str, client_id: Optional[str] = None
|
||||
) -> tuple[Optional[DownloadTask], Optional[str]]:
|
||||
err = self._validate_url(url)
|
||||
if err:
|
||||
return None, err
|
||||
|
||||
err = self._validate_filename(filename)
|
||||
if err:
|
||||
return None, err
|
||||
|
||||
err = self._validate_directory(directory)
|
||||
if err:
|
||||
return None, err
|
||||
|
||||
try:
|
||||
save_path, temp_path, meta_path = self._resolve_save_path(directory, filename)
|
||||
except ValueError as e:
|
||||
return None, str(e)
|
||||
|
||||
if os.path.exists(save_path):
|
||||
return None, f"File already exists: {directory}/{filename}"
|
||||
|
||||
# Reject duplicate active download by URL
|
||||
for task in self.tasks.values():
|
||||
if task.url == url and task.status in ACTIVE_STATUSES:
|
||||
return None, f"Download already in progress for this URL (id: {task.id})"
|
||||
|
||||
# Reject duplicate active download by destination path (#4)
|
||||
for task in self.tasks.values():
|
||||
if task.save_path == save_path and task.status in ACTIVE_STATUSES:
|
||||
return None, f"Download already in progress for {directory}/{filename} (id: {task.id})"
|
||||
|
||||
# Clean stale temp/meta if no active task owns them (#9)
|
||||
existing_meta = self._read_meta(meta_path)
|
||||
if existing_meta:
|
||||
owning_task = self.tasks.get(existing_meta.get("task_id", ""))
|
||||
if not owning_task or owning_task.status in TERMINAL_STATUSES:
|
||||
if existing_meta.get("url") != url:
|
||||
self._cleanup_files(temp_path, meta_path)
|
||||
|
||||
task = DownloadTask(
|
||||
id=uuid.uuid4().hex[:12],
|
||||
url=url,
|
||||
filename=filename,
|
||||
directory=directory,
|
||||
save_path=save_path,
|
||||
temp_path=temp_path,
|
||||
meta_path=meta_path,
|
||||
client_id=client_id,
|
||||
)
|
||||
self.tasks[task.id] = task
|
||||
self._prune_terminal_tasks()
|
||||
|
||||
task._worker = asyncio.create_task(self._run_download(task))
|
||||
return task, None
|
||||
|
||||
# -- Redirect-safe fetch (#1, #2, #3) --
|
||||
|
||||
async def _fetch_with_validated_redirects(
|
||||
self, session: aiohttp.ClientSession, url: str, headers: dict
|
||||
) -> aiohttp.ClientResponse:
|
||||
"""Follow redirects manually, validating each hop against the allowlist."""
|
||||
current_url = url
|
||||
for _ in range(MAX_REDIRECTS + 1):
|
||||
resp = await session.get(current_url, headers=headers, allow_redirects=False)
|
||||
if resp.status not in (301, 302, 303, 307, 308):
|
||||
return resp
|
||||
|
||||
location = resp.headers.get("Location")
|
||||
await resp.release()
|
||||
if not location:
|
||||
raise ValueError("Redirect without Location header")
|
||||
|
||||
resolved = URL(current_url).join(URL(location))
|
||||
current_url = str(resolved)
|
||||
|
||||
# Validate the redirect target host
|
||||
parts = urlsplit(current_url)
|
||||
host = (parts.hostname or "").lower()
|
||||
scheme = parts.scheme.lower()
|
||||
|
||||
if scheme != "https":
|
||||
raise ValueError(f"Redirect to non-HTTPS URL: {current_url}")
|
||||
if host not in ALLOWED_HTTPS_HOSTS:
|
||||
# Allow CDN hosts that HuggingFace/CivitAI commonly redirect to
|
||||
raise ValueError(f"Redirect to disallowed host: {host}")
|
||||
|
||||
# 303 means GET with no Range
|
||||
if resp.status == 303:
|
||||
headers = {k: v for k, v in headers.items() if k.lower() != "range"}
|
||||
|
||||
raise ValueError(f"Too many redirects (>{MAX_REDIRECTS})")
|
||||
|
||||
# -- Download worker --
|
||||
|
||||
async def _run_download(self, task: DownloadTask):
|
||||
try:
|
||||
async with self._semaphore:
|
||||
await self._run_download_inner(task)
|
||||
except asyncio.CancelledError:
|
||||
if task._stop_reason == "pause":
|
||||
task.status = DownloadStatus.PAUSED
|
||||
task.speed_bytes_per_sec = 0
|
||||
task.eta_seconds = 0
|
||||
await self._send_progress(task)
|
||||
else:
|
||||
task.status = DownloadStatus.CANCELLED
|
||||
await self._send_progress(task)
|
||||
self._cleanup_files(task.temp_path, task.meta_path)
|
||||
except Exception as e:
|
||||
task.status = DownloadStatus.ERROR
|
||||
task.error = str(e)
|
||||
await self._send_progress(task)
|
||||
logger.exception("Download error for %s", task.url)
|
||||
|
||||
async def _run_download_inner(self, task: DownloadTask):
|
||||
session = await self._get_session()
|
||||
headers = {}
|
||||
|
||||
# Resume support with sidecar validation (#9)
|
||||
if os.path.exists(task.temp_path):
|
||||
meta = self._read_meta(task.meta_path)
|
||||
if meta and meta.get("url") == task.url:
|
||||
existing_size = os.path.getsize(task.temp_path)
|
||||
if existing_size > 0:
|
||||
headers["Range"] = f"bytes={existing_size}-"
|
||||
task.received_bytes = existing_size
|
||||
else:
|
||||
self._cleanup_files(task.temp_path, task.meta_path)
|
||||
|
||||
self._write_meta(task.meta_path, task.url, task.id)
|
||||
task.status = DownloadStatus.DOWNLOADING
|
||||
await self._send_progress(task)
|
||||
|
||||
resp = await self._fetch_with_validated_redirects(session, task.url, headers)
|
||||
try:
|
||||
if resp.status == 416:
|
||||
content_range = resp.headers.get("Content-Range", "")
|
||||
if content_range:
|
||||
total_str = content_range.split("/")[-1]
|
||||
if total_str != "*":
|
||||
total = int(total_str)
|
||||
if task.received_bytes >= total:
|
||||
if not os.path.exists(task.save_path):
|
||||
os.rename(task.temp_path, task.save_path)
|
||||
self._cleanup_files(task.meta_path)
|
||||
task.status = DownloadStatus.COMPLETED
|
||||
task.progress = 1.0
|
||||
task.total_bytes = total
|
||||
await self._send_progress(task)
|
||||
return
|
||||
raise ValueError(f"HTTP 416 Range Not Satisfiable")
|
||||
|
||||
if resp.status not in (200, 206):
|
||||
task.status = DownloadStatus.ERROR
|
||||
task.error = f"HTTP {resp.status}"
|
||||
await self._send_progress(task)
|
||||
return
|
||||
|
||||
if resp.status == 200:
|
||||
task.received_bytes = 0
|
||||
|
||||
content_length = resp.content_length
|
||||
if resp.status == 206 and content_length:
|
||||
task.total_bytes = task.received_bytes + content_length
|
||||
elif resp.status == 200 and content_length:
|
||||
task.total_bytes = content_length
|
||||
|
||||
mode = "ab" if resp.status == 206 else "wb"
|
||||
speed_window_start = time.monotonic()
|
||||
speed_window_bytes = 0
|
||||
last_progress_time = 0.0
|
||||
|
||||
with open(task.temp_path, mode) as f:
|
||||
async for chunk in resp.content.iter_chunked(1024 * 64):
|
||||
f.write(chunk)
|
||||
task.received_bytes += len(chunk)
|
||||
speed_window_bytes += len(chunk)
|
||||
|
||||
now = time.monotonic()
|
||||
elapsed = now - speed_window_start
|
||||
if elapsed > 0.5:
|
||||
task.speed_bytes_per_sec = speed_window_bytes / elapsed
|
||||
if task.total_bytes > 0 and task.speed_bytes_per_sec > 0:
|
||||
remaining = task.total_bytes - task.received_bytes
|
||||
task.eta_seconds = remaining / task.speed_bytes_per_sec
|
||||
speed_window_start = now
|
||||
speed_window_bytes = 0
|
||||
|
||||
if task.total_bytes > 0:
|
||||
task.progress = task.received_bytes / task.total_bytes
|
||||
|
||||
if now - last_progress_time >= 0.25:
|
||||
await self._send_progress(task)
|
||||
last_progress_time = now
|
||||
finally:
|
||||
resp.release()
|
||||
|
||||
# Final cancel check before committing (#7)
|
||||
if task._stop_reason is not None:
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
# Re-check destination before finalizing (#10)
|
||||
if os.path.exists(task.save_path):
|
||||
task.status = DownloadStatus.ERROR
|
||||
task.error = f"Destination file appeared during download: {task.directory}/{task.filename}"
|
||||
await self._send_progress(task)
|
||||
return
|
||||
|
||||
os.replace(task.temp_path, task.save_path)
|
||||
self._cleanup_files(task.meta_path)
|
||||
task.status = DownloadStatus.COMPLETED
|
||||
task.progress = 1.0
|
||||
task.speed_bytes_per_sec = 0
|
||||
task.eta_seconds = 0
|
||||
await self._send_progress(task)
|
||||
logger.info("Download complete: %s/%s", task.directory, task.filename)
|
||||
|
||||
# -- Progress (#8, #14) --
|
||||
|
||||
async def _send_progress(self, task: DownloadTask):
|
||||
try:
|
||||
self.server.send_sync("download_progress", task.to_dict(), task.client_id)
|
||||
except Exception:
|
||||
logger.exception("Failed to send download progress event")
|
||||
|
||||
# -- Control operations (#5, #6, #13) --
|
||||
|
||||
def pause_download(self, task_id: str) -> Optional[str]:
|
||||
task = self.tasks.get(task_id)
|
||||
if not task:
|
||||
return "Download not found"
|
||||
if task.status not in (DownloadStatus.PENDING, DownloadStatus.DOWNLOADING):
|
||||
return f"Cannot pause download in state '{task.status.value}'"
|
||||
task._stop_reason = "pause"
|
||||
if task._worker and not task._worker.done():
|
||||
task._worker.cancel()
|
||||
return None
|
||||
|
||||
def resume_download(self, task_id: str) -> Optional[str]:
|
||||
task = self.tasks.get(task_id)
|
||||
if not task:
|
||||
return "Download not found"
|
||||
if task.status != DownloadStatus.PAUSED:
|
||||
return f"Cannot resume download in state '{task.status.value}'"
|
||||
task._stop_reason = None
|
||||
task.status = DownloadStatus.PENDING
|
||||
task._worker = asyncio.create_task(self._run_download(task))
|
||||
return None
|
||||
|
||||
def cancel_download(self, task_id: str) -> Optional[str]:
|
||||
task = self.tasks.get(task_id)
|
||||
if not task:
|
||||
return "Download not found"
|
||||
if task.status in TERMINAL_STATUSES:
|
||||
return f"Cannot cancel download in state '{task.status.value}'"
|
||||
task._stop_reason = "cancel"
|
||||
if task._worker and not task._worker.done():
|
||||
task._worker.cancel()
|
||||
else:
|
||||
task.status = DownloadStatus.CANCELLED
|
||||
self._cleanup_files(task.temp_path, task.meta_path)
|
||||
return None
|
||||
|
||||
# -- Query --
|
||||
|
||||
def get_all_tasks(self, client_id: Optional[str] = None) -> list[dict]:
|
||||
tasks = self.tasks.values()
|
||||
if client_id is not None:
|
||||
tasks = [t for t in tasks if t.client_id == client_id]
|
||||
return [t.to_dict() for t in tasks]
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[dict]:
|
||||
task = self.tasks.get(task_id)
|
||||
return task.to_dict() if task else None
|
||||
@@ -224,6 +224,8 @@ parser.add_argument("--user-directory", type=is_valid_directory, default=None, h
|
||||
|
||||
parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
|
||||
|
||||
parser.add_argument("--enable-download-api", action="store_true", help="Enable the model download API. When set, ComfyUI exposes endpoints that allow downloading model files directly into the models directory. Only HTTPS downloads from allowed hosts (huggingface.co, civitai.com) are permitted.")
|
||||
|
||||
parser.add_argument(
|
||||
"--comfy-api-base",
|
||||
type=str,
|
||||
|
||||
@@ -16,6 +16,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
"download_api": args.enable_download_api,
|
||||
}
|
||||
|
||||
|
||||
|
||||
84
server.py
84
server.py
@@ -43,6 +43,7 @@ from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from app.download_manager import DownloadManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@@ -205,6 +206,7 @@ class PromptServer():
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.node_replace_manager = NodeReplaceManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.download_manager = DownloadManager(self) if args.enable_download_api else None
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
self.loop = loop
|
||||
@@ -1028,9 +1030,91 @@ class PromptServer():
|
||||
|
||||
return web.Response(status=200)
|
||||
|
||||
# -- Download API (gated behind --enable-download-api) --
|
||||
|
||||
def _require_download_api(handler):
|
||||
async def wrapper(request):
|
||||
if self.download_manager is None:
|
||||
return web.json_response(
|
||||
{"error": "Download API is not enabled. Start ComfyUI with --enable-download-api."},
|
||||
status=403,
|
||||
)
|
||||
return await handler(request)
|
||||
return wrapper
|
||||
|
||||
@routes.post("/download/model")
|
||||
@_require_download_api
|
||||
async def post_download_model(request):
|
||||
json_data = await request.json()
|
||||
url = json_data.get("url")
|
||||
directory = json_data.get("directory")
|
||||
filename = json_data.get("filename")
|
||||
client_id = json_data.get("client_id")
|
||||
|
||||
if not url or not directory or not filename:
|
||||
return web.json_response(
|
||||
{"error": "Missing required fields: url, directory, filename"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
task, err = await self.download_manager.start_download(url, directory, filename, client_id=client_id)
|
||||
if err:
|
||||
status = 409 if "already" in err.lower() else 400
|
||||
return web.json_response({"error": err}, status=status)
|
||||
|
||||
return web.json_response(task.to_dict(), status=201)
|
||||
|
||||
@routes.get("/download/status")
|
||||
@_require_download_api
|
||||
async def get_download_status(request):
|
||||
client_id = request.rel_url.query.get("client_id")
|
||||
return web.json_response(self.download_manager.get_all_tasks(client_id=client_id))
|
||||
|
||||
@routes.get("/download/status/{task_id}")
|
||||
@_require_download_api
|
||||
async def get_download_task_status(request):
|
||||
task_id = request.match_info["task_id"]
|
||||
task_data = self.download_manager.get_task(task_id)
|
||||
if task_data is None:
|
||||
return web.json_response({"error": "Download not found"}, status=404)
|
||||
return web.json_response(task_data)
|
||||
|
||||
@routes.post("/download/pause/{task_id}")
|
||||
@_require_download_api
|
||||
async def post_download_pause(request):
|
||||
task_id = request.match_info["task_id"]
|
||||
err = self.download_manager.pause_download(task_id)
|
||||
if err:
|
||||
return web.json_response({"error": err}, status=400)
|
||||
return web.json_response({"status": "paused"})
|
||||
|
||||
@routes.post("/download/resume/{task_id}")
|
||||
@_require_download_api
|
||||
async def post_download_resume(request):
|
||||
task_id = request.match_info["task_id"]
|
||||
err = self.download_manager.resume_download(task_id)
|
||||
if err:
|
||||
return web.json_response({"error": err}, status=400)
|
||||
return web.json_response({"status": "resumed"})
|
||||
|
||||
@routes.post("/download/cancel/{task_id}")
|
||||
@_require_download_api
|
||||
async def post_download_cancel(request):
|
||||
task_id = request.match_info["task_id"]
|
||||
err = self.download_manager.cancel_download(task_id)
|
||||
if err:
|
||||
return web.json_response({"error": err}, status=400)
|
||||
return web.json_response({"status": "cancelled"})
|
||||
|
||||
async def setup(self):
|
||||
timeout = aiohttp.ClientTimeout(total=None) # no timeout
|
||||
self.client_session = aiohttp.ClientSession(timeout=timeout)
|
||||
if self.download_manager is not None:
|
||||
self.app.on_cleanup.append(self._cleanup_download_manager)
|
||||
|
||||
async def _cleanup_download_manager(self, app):
|
||||
if self.download_manager is not None:
|
||||
await self.download_manager.close()
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
|
||||
Reference in New Issue
Block a user