From 4915593e72e7bab36d1c54470a1336cf18327aa0 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Tue, 30 Jan 2024 18:47:18 -0800 Subject: [PATCH] Update forge_reference.py --- .../scripts/forge_reference.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py index 764d8c09..a870b36b 100644 --- a/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py +++ b/extensions-builtin/forge_preprocessor_reference/scripts/forge_reference.py @@ -2,6 +2,7 @@ import torch from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter from modules_forge.shared import add_supported_preprocessor +from ldm_patched.modules.samplers import sampling_function class PreprocessorReference(Preprocessor): @@ -33,6 +34,9 @@ class PreprocessorReference(Preprocessor): latent_image = vae.encode(cond.movedim(1, -1)) latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image) + gen_seed = process.seeds[0] + 1 + gen_cpu = torch.Generator().manual_seed(gen_seed) + unet = process.sd_model.forge_objects.unet.clone() sigma_max = unet.model.model_sampling.percent_to_sigma(start_percent) sigma_min = unet.model.model_sampling.percent_to_sigma(end_percent) @@ -44,6 +48,9 @@ class PreprocessorReference(Preprocessor): self.is_recording_style = True + xt = latent_image.to(x) + torch.randn(x.size(), dtype=x.dtype, generator=gen_cpu).to(x) * sigma + sampling_function(model, xt, timestep, uncond, cond, 1, model_options, seed) + self.is_recording_style = False return model, x, timestep, uncond, cond, cond_scale, model_options, seed @@ -53,7 +60,10 @@ class PreprocessorReference(Preprocessor): if not (sigma_min <= sigma <= sigma_max): return h - a = 0 + if self.is_recording_style: + a = 0 + else: + b = 0 return h @@ -62,7 +72,10 @@ class PreprocessorReference(Preprocessor): if not (sigma_min <= sigma <= sigma_max): return q, k, v - a = 0 + if self.is_recording_style: + a = 0 + else: + b = 0 return q, k, v @@ -71,7 +84,10 @@ class PreprocessorReference(Preprocessor): if not (sigma_min <= sigma <= sigma_max): return h - a = 0 + if self.is_recording_style: + a = 0 + else: + b = 0 return h