From 848f28af6e94d632a3ed5ae4c5138b7ba814e1ee Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 19:37:33 -0800 Subject: [PATCH] i --- .../scripts/sd_forge_controlnet_example.py | 25 +++++++++++++------ modules/processing.py | 22 ++++++++++++++++ modules/scripts.py | 16 ++++++++++++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py index be695ac1..c7ea226b 100644 --- a/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py +++ b/extensions-builtin/sd_forge_controlnet_example/scripts/sd_forge_controlnet_example.py @@ -29,9 +29,7 @@ class ControlNetExampleForge(scripts.Script): return input_image, funny_slider - def process_batch(self, p, *script_args, **kwargs): - # This function will be called every batch. Use your own way to cache. - + def process(self, p, *script_args, **kwargs): input_image, funny_slider = script_args # This slider does nothing. It just shows you how to transfer parameters. @@ -40,8 +38,6 @@ class ControlNetExampleForge(scripts.Script): if input_image is None: return - print('Input image is read.') - model_dir = os.path.join(models_path, 'ControlNet') os.makedirs(model_dir, exist_ok=True) controlnet_canny_path = load_file_from_url( @@ -54,10 +50,25 @@ class ControlNetExampleForge(scripts.Script): controlnet = load_controlnet(controlnet_canny_path) print('Controlnet loaded.') - input_image = cv2.resize(input_image, (p.height, p.width)) + return + + def process_before_every_sampling(self, p, *script_args, **kwargs): + # This will be called before every sampling. + # If you use highres fix, this will be called twice. + + input_image, funny_slider = script_args + + if input_image is None: + return + + B, C, H, W = kwargs['noise'].shape # latent_shape + height = H * 8 + width = W * 8 + + input_image = cv2.resize(input_image, (height, width)) canny_image = cv2.Canny(input_image, 100, 200) - # Display preprocessor result. Called every batch. Cache in your own way. + # Display preprocessor result. Called every sampling. Cache in your own way. p.extra_result_images.append(canny_image) print('Preprocessor Canny finished.') diff --git a/modules/processing.py b/modules/processing.py index 7a09105a..c25ab426 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1247,6 +1247,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): # here we generate an image normally x = self.rng.next() + + if self.scripts is not None: + self.scripts.process_before_every_sampling(self, + x=x, + noise=x, + c=conditioning, + uc=unconditional_conditioning) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1348,6 +1356,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.scripts is not None: self.scripts.before_hr(self) + if self.scripts is not None: + self.scripts.process_before_every_sampling(self, + x=samples, + noise=noise, + c=self.hr_c, + uc=self.hr_uc) + samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) @@ -1651,6 +1666,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier x *= self.initial_noise_multiplier + if self.scripts is not None: + self.scripts.process_before_every_sampling(self, + x=self.init_latent, + noise=x, + c=conditioning, + uc=unconditional_conditioning) + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: diff --git a/modules/scripts.py b/modules/scripts.py index 94690a22..f490ee5d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -186,6 +186,14 @@ class Script: """ pass + def process_before_every_sampling(self, p, *args, **kwargs): + """ + Similar to process(), called before every sampling. + If you use high-res fix, this will be called two times. + """ + + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -809,6 +817,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running process_batch: {script.filename}", exc_info=True) + def process_before_every_sampling(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.process_before_every_sampling(p, *script_args, **kwargs) + except Exception: + errors.report(f"Error running process_before_every_sampling: {script.filename}", exc_info=True) + def postprocess(self, p, processed): for script in self.alwayson_scripts: try: