mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 18:51:31 +00:00
Update controlnet.py
This commit is contained in:
@@ -166,22 +166,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:
|
||||
return y
|
||||
|
||||
|
||||
class Script(scripts.Script, metaclass=(
|
||||
utils.TimeMeta if logger.level == logging.DEBUG else type)):
|
||||
|
||||
model_cache = OrderedDict()
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.latest_network = None
|
||||
self.input_image = None
|
||||
self.latest_model_hash = ""
|
||||
self.enabled_units = []
|
||||
self.detected_map = []
|
||||
self.post_processors = []
|
||||
self.noise_modifier = None
|
||||
self.ui_batch_option_state = [external_code.BatchOption.DEFAULT.value, False]
|
||||
|
||||
class ControlNetForForgeOfficial(scripts.Script):
|
||||
def title(self):
|
||||
return "ControlNet"
|
||||
|
||||
@@ -205,45 +190,6 @@ class Script(scripts.Script, metaclass=(
|
||||
)
|
||||
return group, group.render(tabname, elem_id_tabname)
|
||||
|
||||
def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str):
|
||||
batch_option = gr.Radio(
|
||||
choices=[e.value for e in external_code.BatchOption],
|
||||
value=external_code.BatchOption.DEFAULT.value,
|
||||
label="Batch Option",
|
||||
elem_id=f"{elem_id_tabname}_controlnet_batch_option_radio",
|
||||
elem_classes="controlnet_batch_option_radio",
|
||||
)
|
||||
use_batch_style_align = gr.Checkbox(
|
||||
label='[StyleAlign] Align image style in the batch.'
|
||||
)
|
||||
|
||||
unit_args = [batch_option, use_batch_style_align]
|
||||
|
||||
def update_ui_batch_options(*args):
|
||||
self.ui_batch_option_state = args
|
||||
return
|
||||
|
||||
for comp in unit_args:
|
||||
event_subscribers = []
|
||||
if hasattr(comp, "edit"):
|
||||
event_subscribers.append(comp.edit)
|
||||
elif hasattr(comp, "click"):
|
||||
event_subscribers.append(comp.click)
|
||||
elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
|
||||
event_subscribers.append(comp.release)
|
||||
elif hasattr(comp, "change"):
|
||||
event_subscribers.append(comp.change)
|
||||
|
||||
if hasattr(comp, "clear"):
|
||||
event_subscribers.append(comp.clear)
|
||||
|
||||
for event_subscriber in event_subscribers:
|
||||
event_subscriber(
|
||||
fn=update_ui_batch_options, inputs=unit_args
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def ui(self, is_img2img):
|
||||
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
|
||||
The return value should be an array of all components that are used in processing.
|
||||
@@ -1012,49 +958,18 @@ class Script(scripts.Script, metaclass=(
|
||||
return
|
||||
|
||||
def process(self, p, *args, **kwargs):
|
||||
for unit in Script.get_enabled_units(p):
|
||||
for unit in self.get_enabled_units(p):
|
||||
self.process_unit_after_click_generate(p, unit, *args, **kwargs)
|
||||
return
|
||||
|
||||
def process_before_every_sampling(self, p, *args, **kwargs):
|
||||
for unit in Script.get_enabled_units(p):
|
||||
for unit in self.get_enabled_units(p):
|
||||
self.process_unit_before_every_sampling(p, unit, *args, **kwargs)
|
||||
return
|
||||
|
||||
def postprocess(self, p, processed, *args):
|
||||
return
|
||||
|
||||
def batch_tab_process(self, p, batches, *args, **kwargs):
|
||||
self.enabled_units = Script.get_enabled_units(p)
|
||||
for unit_i, unit in enumerate(self.enabled_units):
|
||||
unit.batch_images = iter([batch[unit_i] for batch in batches])
|
||||
|
||||
def batch_tab_process_each(self, p, *args, **kwargs):
|
||||
for unit_i, unit in enumerate(self.enabled_units):
|
||||
if getattr(unit, 'loopback', False):
|
||||
continue
|
||||
|
||||
unit.image = next(unit.batch_images)
|
||||
|
||||
def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
|
||||
for unit_i, unit in enumerate(self.enabled_units):
|
||||
if getattr(unit, 'loopback', False):
|
||||
output_images = getattr(processed, 'images', [])[processed.index_of_first_image:]
|
||||
if output_images:
|
||||
unit.image = np.array(output_images[0])
|
||||
else:
|
||||
logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. '
|
||||
f'Using control map from last batch iteration instead')
|
||||
|
||||
def batch_tab_postprocess(self, p, *args, **kwargs):
|
||||
self.enabled_units.clear()
|
||||
self.input_image = None
|
||||
if self.latest_network is None: return
|
||||
|
||||
self.latest_network.restore()
|
||||
self.latest_network = None
|
||||
self.detected_map.clear()
|
||||
|
||||
|
||||
def on_ui_settings():
|
||||
section = ('control_net', "ControlNet")
|
||||
|
||||
Reference in New Issue
Block a user