Update controlnet.py

This commit is contained in:
lllyasviel
2024-01-29 15:03:12 -08:00
parent d67c4db227
commit 5b58c21464

View File

@@ -166,6 +166,12 @@ def get_pytorch_control(x: np.ndarray) -> torch.Tensor:
return y
class ControlNetCachedParameters:
def __init__(self):
self.control_image = None
self.control_image_for_hr_fix = None
class ControlNetForForgeOfficial(scripts.Script):
def title(self):
return "ControlNet"
@@ -941,20 +947,23 @@ class ControlNetForForgeOfficial(scripts.Script):
self.detected_map = detected_maps
self.post_processors = post_processors
def process_unit_after_click_generate(self, p, unit, *args, **kwargs):
def process_unit_after_click_generate(self, p, unit, params, *args, **kwargs):
return
def process_unit_before_every_sampling(self, p, unit, *args, **kwargs):
def process_unit_before_every_sampling(self, p, unit, params, *args, **kwargs):
return
def process(self, p, *args, **kwargs):
for unit in self.get_enabled_units(p):
self.process_unit_after_click_generate(p, unit, *args, **kwargs)
self.current_params = {}
for i, unit in enumerate(self.get_enabled_units(p)):
params = ControlNetCachedParameters()
self.process_unit_after_click_generate(p, unit, params, *args, **kwargs)
self.current_params[i] = params
return
def process_before_every_sampling(self, p, *args, **kwargs):
for unit in self.get_enabled_units(p):
self.process_unit_before_every_sampling(p, unit, *args, **kwargs)
for unit, params in zip(self.get_enabled_units(p), self.current_params):
self.process_unit_before_every_sampling(p, unit, params, *args, **kwargs)
return
def postprocess(self, p, processed, *args):