Add model downloading endpoint.

This commit is contained in:
Robin Huang
2024-08-06 18:07:32 -07:00
parent b334605a66
commit 6976ccc5ca
6 changed files with 177 additions and 1 deletions

View File

@@ -12,7 +12,6 @@ import json
import glob
import struct
import ssl
import hashlib
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
@@ -28,6 +27,7 @@ import comfy.model_management
import node_helpers
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from model_filemanager import download_model, DownloadStatus
class BinaryEventTypes:
@@ -76,6 +76,8 @@ class PromptServer():
self.prompt_queue = None
self.loop = loop
self.messages = asyncio.Queue()
timeout = aiohttp.ClientTimeout(total=None) # no timeout
self.client_session = aiohttp.ClientSession(timeout=timeout)
self.number = 0
middlewares = [cache_control]
@@ -559,6 +561,28 @@ class PromptServer():
self.prompt_queue.delete_history_item(id_to_delete)
return web.Response(status=200)
@routes.post("/download")
async def download_handler(request):
async def report_progress(filename: str, status: DownloadStatus):
await self.send_json(filename, {
"progress_percentage": status.progress_percentage,
"status": status.status,
"message": status.message
})
data = await request.json()
url = data.get('url')
model_directory = data.get('model_directory')
model_filename = data.get('model_filename')
if not url or not model_directory or not model_filename:
return web.json_response({"status": "error", "message": "Missing URL or folder path or filename"}, status=400)
task = asyncio.create_task(download_model(self.client_session, model_filename, url, model_directory, report_progress))
await task
return web.Response(status=200)
def add_routes(self):
self.user_manager.add_routes(self.routes)
@@ -698,3 +722,6 @@ class PromptServer():
logging.warning(traceback.format_exc())
return json_data
def close_session(self):
self.client_session.close()