mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 19:21:21 +00:00
Cnet (#22)
* ini * remove shit * Create control_model.py * i * i * Update controlnet_supported.py * Update controlnet_supported.py * Update controlnet_supported.py * i * i * Update controlnet_supported.py * i * Update controlnet_supported.py * remove shits * remove shit * Update global_state.py * i * i * Update legacy_preprocessors.py * Update legacy_preprocessors.py * remove shit * Update batch_hijack.py * remove shit * remove shit * i * i * i * Update external_code.py * Update global_state.py * Update infotext.py * Update utils.py * Update external_code.py * i * i * i * Update controlnet_ui_group.py * remove shit * remove shit * i * Update controlnet.py * Update controlnet.py * Update controlnet.py * Update controlnet.py * Update controlnet.py * i * Update global_state.py * Update global_state.py * i * Update global_state.py * Update global_state.py * Update global_state.py * Update global_state.py * Update controlnet_ui_group.py * i * Update global_state.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py * i * Update controlnet_ui_group.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py * Update controlnet_ui_group.py
This commit is contained in:
@@ -0,0 +1,135 @@
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules.processing import StableDiffusionProcessing
|
||||
|
||||
from lib_controlnet import external_code
|
||||
from lib_controlnet.logging import logger
|
||||
|
||||
|
||||
def field_to_displaytext(fieldname: str) -> str:
|
||||
return " ".join([word.capitalize() for word in fieldname.split("_")])
|
||||
|
||||
|
||||
def displaytext_to_field(text: str) -> str:
|
||||
return "_".join([word.lower() for word in text.split(" ")])
|
||||
|
||||
|
||||
def parse_value(value: str) -> Union[str, float, int, bool]:
|
||||
if value in ("True", "False"):
|
||||
return value == "True"
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
return value # Plain string.
|
||||
|
||||
|
||||
def serialize_unit(unit: external_code.ControlNetUnit) -> str:
|
||||
excluded_fields = (
|
||||
"image",
|
||||
"enabled",
|
||||
# Note: "advanced_weighting" is excluded as it is an API-only field.
|
||||
"advanced_weighting",
|
||||
# Note: "inpaint_crop_image" is img2img inpaint only flag, which does not
|
||||
# provide much information when restoring the unit.
|
||||
"inpaint_crop_input_image",
|
||||
)
|
||||
|
||||
log_value = {
|
||||
field_to_displaytext(field): getattr(unit, field)
|
||||
for field in vars(external_code.ControlNetUnit()).keys()
|
||||
if field not in excluded_fields and getattr(unit, field) != -1
|
||||
# Note: exclude hidden slider values.
|
||||
}
|
||||
if not all("," not in str(v) and ":" not in str(v) for v in log_value.values()):
|
||||
logger.error(f"Unexpected tokens encountered:\n{log_value}")
|
||||
return ""
|
||||
|
||||
return ", ".join(f"{field}: {value}" for field, value in log_value.items())
|
||||
|
||||
|
||||
def parse_unit(text: str) -> external_code.ControlNetUnit:
|
||||
return external_code.ControlNetUnit(
|
||||
enabled=True,
|
||||
**{
|
||||
displaytext_to_field(key): parse_value(value)
|
||||
for item in text.split(",")
|
||||
for (key, value) in (item.strip().split(": "),)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Infotext(object):
|
||||
def __init__(self) -> None:
|
||||
self.infotext_fields: List[Tuple[gr.components.IOComponent, str]] = []
|
||||
self.paste_field_names: List[str] = []
|
||||
|
||||
@staticmethod
|
||||
def unit_prefix(unit_index: int) -> str:
|
||||
return f"ControlNet {unit_index}"
|
||||
|
||||
def register_unit(self, unit_index: int, uigroup) -> None:
|
||||
"""Register the unit's UI group. By regsitering the unit, A1111 will be
|
||||
able to paste values from infotext to IOComponents.
|
||||
|
||||
Args:
|
||||
unit_index: The index of the ControlNet unit
|
||||
uigroup: The ControlNetUiGroup instance that contains all gradio
|
||||
iocomponents.
|
||||
"""
|
||||
unit_prefix = Infotext.unit_prefix(unit_index)
|
||||
for field in vars(external_code.ControlNetUnit()).keys():
|
||||
# Exclude image for infotext.
|
||||
if field == "image":
|
||||
continue
|
||||
|
||||
# Every field in ControlNetUnit should have a cooresponding
|
||||
# IOComponent in ControlNetUiGroup.
|
||||
io_component = getattr(uigroup, field)
|
||||
component_locator = f"{unit_prefix} {field}"
|
||||
self.infotext_fields.append((io_component, component_locator))
|
||||
self.paste_field_names.append(component_locator)
|
||||
|
||||
@staticmethod
|
||||
def write_infotext(
|
||||
units: List[external_code.ControlNetUnit], p: StableDiffusionProcessing
|
||||
):
|
||||
"""Write infotext to `p`."""
|
||||
p.extra_generation_params.update(
|
||||
{
|
||||
Infotext.unit_prefix(i): serialize_unit(unit)
|
||||
for i, unit in enumerate(units)
|
||||
if unit.enabled
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def on_infotext_pasted(infotext: str, results: dict) -> None:
|
||||
"""Parse ControlNet infotext string and write result to `results` dict."""
|
||||
updates = {}
|
||||
for k, v in results.items():
|
||||
if not k.startswith("ControlNet"):
|
||||
continue
|
||||
|
||||
assert isinstance(v, str), f"Expect string but got {v}."
|
||||
try:
|
||||
for field, value in vars(parse_unit(v)).items():
|
||||
if field == "image":
|
||||
continue
|
||||
if value is None:
|
||||
logger.debug(f"InfoText: Skipping {field} because value is None.")
|
||||
continue
|
||||
|
||||
component_locator = f"{k} {field}"
|
||||
updates[component_locator] = value
|
||||
logger.debug(f"InfoText: Setting {component_locator} = {value}")
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}\n{e}"
|
||||
)
|
||||
|
||||
results.update(updates)
|
||||
Reference in New Issue
Block a user