diff --git a/backend/sampling/condition.py b/backend/sampling/condition.py
index ff6eb7a2..b594d1b8 100644
--- a/backend/sampling/condition.py
+++ b/backend/sampling/condition.py
@@ -89,6 +89,9 @@ class ConditionConstant(Condition):
 
 
 def compile_conditions(cond):
+    if cond is None:
+        return None
+
     if isinstance(cond, torch.Tensor):
         result = dict(
             cross_attn=cond,
diff --git a/modules/processing.py b/modules/processing.py
index d35a8303..85f7737b 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -483,7 +483,12 @@ class StableDiffusionProcessing:
         self.step_multiplier = total_steps // self.steps
         self.firstpass_steps = total_steps
 
-        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
+        if self.cfg_scale == 1:
+            self.uc = None
+            print('Skipping unconditional conditioning when CFG = 1')
+        else:
+            self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
+
         self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
 
     def get_conds(self):
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b329a6c9..d7146600 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -173,7 +173,7 @@ class CFGDenoiser(torch.nn.Module):
             uncond = self.sampler.sampler_extra_args['uncond']
 
         cond_composition, cond = prompt_parser.reconstruct_multicond_batch(cond, self.step)
-        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) if uncond is not None else None
 
         if self.mask is not None:
             noisy_initial_latent = self.init_latent + sigma[:, None, None, None] * torch.randn_like(self.init_latent).to(self.init_latent)