From 0eea1acc6e2b2675dce113988ec2e6a22e3493c8 Mon Sep 17 00:00:00 2001 From: altoiddealer Date: Thu, 1 Aug 2024 15:48:30 -0400 Subject: [PATCH] Restore '/controlnet/control_types' API endpoint (#912) Restores the '/controlnet/control_types' API endpoint, which is immensely useful for anyone using ControlNet via the API ## Description I recently opened an Issue on the main ControlNet extension repo Mikubill/sd-webui-controlnet#2737 suggesting that they add a new API endpoint to allow users to retrieve filtered data based on a Control Type, just like in the UI. I was both shocked and immensely disappointed when they finally replied, stating that the endpoint does already exist! I understand that Forge is a massive overhaul to A1111, so perhaps this aspect was removed to get everything working, and then just never reimplemented. Whatever the case, this endpoint is truly amazing for using ControlNet via API. With only the 'models' and 'modules' endpoints, how the heck is someone to make a dynamic script? They would essentially have to take a fat chunk of existing ControlNet code, plus these few added functions, just to filter the data appropriately. I'm an amateur coder, at best, however I'm quite confident about this implementation. This uses your existing functions as best as possible, I believe, including your filter list and the check for currently loaded SD model version. Please merge this. Thank you ## Screenshots/videos: Restored [response_1714160176770.json](https://github.com/lllyasviel/stable-diffusion-webui-forge/files/15134692/response_1714160176770.json) ## Checklist: - [X] I have read [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) - [X] I have performed a self-review of my own code - [X] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style) - [X] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests) --- .../sd_forge_controlnet/lib_controlnet/api.py | 26 +++++++++++ .../lib_controlnet/global_state.py | 46 +++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py index 29e7d662..8bd43ae8 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/api.py @@ -11,6 +11,8 @@ from .global_state import ( get_all_preprocessor_names, get_all_controlnet_names, get_preprocessor, + get_all_preprocessor_tags, + select_control_type, ) from .utils import judge_image_type from .logging import logger @@ -53,6 +55,30 @@ def controlnet_api(_: gr.Blocks, app: FastAPI): # "module_detail": external_code.get_modules_detail(alias_names), } + @app.get("/controlnet/control_types") + async def control_types(): + def format_control_type( + filtered_preprocessor_list, + filtered_model_list, + default_option, + default_model, + ): + control_dict = { + "module_list": filtered_preprocessor_list, + "model_list": filtered_model_list, + "default_option": default_option, + "default_model": default_model, + } + + return control_dict + + return { + "control_types": { + control_type: format_control_type(*select_control_type(control_type)) + for control_type in get_all_preprocessor_tags() + } + } + @app.post("/controlnet/detect") async def detect( controlnet_module: str = Body("none", title="Controlnet Module"), diff --git a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py index dc87dda8..15bef936 100644 --- a/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py +++ b/extensions-builtin/sd_forge_controlnet/lib_controlnet/global_state.py @@ -6,6 +6,7 @@ from modules import shared, sd_models from lib_controlnet.enums import StableDiffusionVersion from modules_forge.shared import controlnet_dir, supported_preprocessors +from typing import Dict, Tuple, List CN_MODEL_EXTS = [".pt", ".pth", ".ckpt", ".safetensors", ".bin", ".patch"] @@ -56,6 +57,10 @@ controlnet_names = ['None'] def get_preprocessor(name): return supported_preprocessors.get(name, None) +def get_default_preprocessor(tag): + ps = get_filtered_preprocessor_names(tag) + assert len(ps) > 0 + return ps[0] if len(ps) == 1 else ps[1] def get_sorted_preprocessors(): preprocessors = [p for k, p in supported_preprocessors.items() if k != 'None'] @@ -144,3 +149,44 @@ def get_sd_version() -> StableDiffusionVersion: return StableDiffusionVersion.SD1x else: return StableDiffusionVersion.UNKNOWN + + +def select_control_type( + control_type: str, + sd_version: StableDiffusionVersion = StableDiffusionVersion.UNKNOWN, +) -> Tuple[List[str], List[str], str, str]: + global controlnet_names + + pattern = control_type.lower() + all_models = list(controlnet_names) + + if pattern == "all": + preprocessors = get_sorted_preprocessors().values() + return [ + [p.name for p in preprocessors], + all_models, + 'none', # default option + "None" # default model + ] + + filtered_model_list = get_filtered_controlnet_names(control_type) + + if pattern == "none": + filtered_model_list.append("None") + + assert len(filtered_model_list) > 0, "'None' model should always be available." + if len(filtered_model_list) == 1: + default_model = "None" + else: + default_model = filtered_model_list[1] + for x in filtered_model_list: + if "11" in x.split("[")[0]: + default_model = x + break + + return ( + get_filtered_preprocessor_names(control_type), + filtered_model_list, + get_default_preprocessor(control_type), + default_model + )