Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 14:45:44 -08:00
parent ec7adb41fa
commit 05d937dcd2

View File

@@ -166,22 +166,7 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:
return y
class Script(scripts.Script, metaclass=(
utils.TimeMeta if logger.level == logging.DEBUG else type)):
model_cache = OrderedDict()
def __init__(self) -> None:
super().__init__()
self.latest_network = None
self.input_image = None
self.latest_model_hash = ""
self.enabled_units = []
self.detected_map = []
self.post_processors = []
self.noise_modifier = None
self.ui_batch_option_state = [external_code.BatchOption.DEFAULT.value, False]
class ControlNetForForgeOfficial(scripts.Script):
def title(self):
return "ControlNet"
@@ -205,45 +190,6 @@ class Script(scripts.Script, metaclass=(
)
return group, group.render(tabname, elem_id_tabname)
def ui_batch_options(self, is_img2img: bool, elem_id_tabname: str):
batch_option = gr.Radio(
choices=[e.value for e in external_code.BatchOption],
value=external_code.BatchOption.DEFAULT.value,
label="Batch Option",
elem_id=f"{elem_id_tabname}_controlnet_batch_option_radio",
elem_classes="controlnet_batch_option_radio",
)
use_batch_style_align = gr.Checkbox(
label='[StyleAlign] Align image style in the batch.'
)
unit_args = [batch_option, use_batch_style_align]
def update_ui_batch_options(*args):
self.ui_batch_option_state = args
return
for comp in unit_args:
event_subscribers = []
if hasattr(comp, "edit"):
event_subscribers.append(comp.edit)
elif hasattr(comp, "click"):
event_subscribers.append(comp.click)
elif isinstance(comp, gr.Slider) and hasattr(comp, "release"):
event_subscribers.append(comp.release)
elif hasattr(comp, "change"):
event_subscribers.append(comp.change)
if hasattr(comp, "clear"):
event_subscribers.append(comp.clear)
for event_subscriber in event_subscribers:
event_subscriber(
fn=update_ui_batch_options, inputs=unit_args
)
return
def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
@@ -1012,49 +958,18 @@ class Script(scripts.Script, metaclass=(
return
def process(self, p, *args, **kwargs):
for unit in Script.get_enabled_units(p):
for unit in self.get_enabled_units(p):
self.process_unit_after_click_generate(p, unit, *args, **kwargs)
return
def process_before_every_sampling(self, p, *args, **kwargs):
for unit in Script.get_enabled_units(p):
for unit in self.get_enabled_units(p):
self.process_unit_before_every_sampling(p, unit, *args, **kwargs)
return
def postprocess(self, p, processed, *args):
return
def batch_tab_process(self, p, batches, *args, **kwargs):
self.enabled_units = Script.get_enabled_units(p)
for unit_i, unit in enumerate(self.enabled_units):
unit.batch_images = iter([batch[unit_i] for batch in batches])
def batch_tab_process_each(self, p, *args, **kwargs):
for unit_i, unit in enumerate(self.enabled_units):
if getattr(unit, 'loopback', False):
continue
unit.image = next(unit.batch_images)
def batch_tab_postprocess_each(self, p, processed, *args, **kwargs):
for unit_i, unit in enumerate(self.enabled_units):
if getattr(unit, 'loopback', False):
output_images = getattr(processed, 'images', [])[processed.index_of_first_image:]
if output_images:
unit.image = np.array(output_images[0])
else:
logger.warning(f'Warning: No loopback image found for controlnet unit {unit_i}. '
f'Using control map from last batch iteration instead')
def batch_tab_postprocess(self, p, *args, **kwargs):
self.enabled_units.clear()
self.input_image = None
if self.latest_network is None: return
self.latest_network.restore()
self.latest_network = None
self.detected_map.clear()
def on_ui_settings():
section = ('control_net', "ControlNet")