From 57a294b111cc1b73596180bd6e52aad1f809b6ab Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 27 Jan 2024 13:21:25 -0800 Subject: [PATCH] better sampler --- modules/sd_samplers_cfg_denoiser.py | 171 +--------------------------- modules_forge/forge_loader.py | 24 +--- modules_forge/forge_sampler.py | 49 ++++++++ modules_forge/forge_util.py | 27 ----- 4 files changed, 55 insertions(+), 216 deletions(-) create mode 100644 modules_forge/forge_sampler.py diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 8aa26982..9e1ae2a6 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,7 +6,7 @@ import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback -from ldm_patched.modules import model_management +from modules_forge import forge_sampler def catenate_conds(conds): @@ -76,41 +76,7 @@ class CFGDenoiser(torch.nn.Module): for cond_index, weight in conds: denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - if "sampler_cfg_function" in model_options or "sampler_post_cfg_function" in model_options: - cond_scale = float(cond_scale) - model = self.inner_model.inner_model.forge_objects.unet.model - x = x_in[-uncond.shape[0]:] - uncond_pred = denoised_uncond - cond_pred = ((denoised - uncond_pred) / cond_scale) + uncond_pred - timestep = timestep[-uncond.shape[0]:] - - from modules_forge.forge_util import cond_from_a1111_to_patched_ldm - - if "sampler_cfg_function" in model_options: - args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, - "timestep": timestep, "input": x, "sigma": timestep, - "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, - "model_options": model_options} - cfg_result = x - model_options["sampler_cfg_function"](args) - else: - cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale - # sanity_check = torch.allclose(cfg_result, denoised) - - for fn in model_options.get("sampler_post_cfg_function", []): - args = {"denoised": cfg_result, - "cond": cond_from_a1111_to_patched_ldm(cond), - "uncond": cond_from_a1111_to_patched_ldm(uncond), - "model": model, - "uncond_denoised": uncond_pred, - "cond_denoised": cond_pred, - "sigma": timestep, - "model_options": model_options, - "input": x} - cfg_result = fn(args) - else: - cfg_result = denoised - - return cfg_result + return denoised def combine_denoised_for_edit_model(self, x_out, cond_scale): out_cond, out_img_cond, out_uncond = x_out.chunk(3) @@ -161,138 +127,11 @@ class CFGDenoiser(torch.nn.Module): if self.mask_before_denoising and self.mask is not None: x = apply_blend(x) - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - if shared.sd_model.model.conditioning_key == "crossattn-adm": - image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} - else: - image_uncond = image_cond - if isinstance(uncond, dict): - make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} - else: - make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} - - if not is_edit_model: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) - else: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond, self) + denoiser_params = CFGDenoiserParams(x, image_cond, sigma, state.sampling_step, state.sampling_steps, tensor, uncond, self) cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - tensor = denoiser_params.text_cond - uncond = denoiser_params.text_uncond - skip_uncond = False - - # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it - if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: - skip_uncond = True - x_in = x_in[:-batch_size] - sigma_in = sigma_in[:-batch_size] - - self.padded_cond_uncond = False - if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: - empty = shared.sd_model.cond_stage_model_empty_prompt - num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] - - if num_repeats < 0: - tensor = pad_cond(tensor, -num_repeats, empty) - self.padded_cond_uncond = True - elif num_repeats > 0: - uncond = pad_cond(uncond, num_repeats, empty) - self.padded_cond_uncond = True - - unet_input_dtype = torch.float16 if model_management.should_use_fp16() else torch.float32 - x_input_dtype = x_in.dtype - - x_in = x_in.to(unet_input_dtype) - sigma_in = sigma_in.to(unet_input_dtype) - image_cond_in = image_cond_in.to(unet_input_dtype) - tensor = tensor.to(unet_input_dtype) - uncond = uncond.to(unet_input_dtype) - - self.inner_model.inner_model.current_sigmas = sigma_in - - if tensor.shape[1] == uncond.shape[1] or skip_uncond: - if is_edit_model: - cond_in = catenate_conds([tensor, uncond, uncond]) - cond_or_uncond = [0] * int(tensor.shape[0]) + [1] * int(uncond.shape[0]) + [1] * int(uncond.shape[0]) - elif skip_uncond: - cond_in = tensor - cond_or_uncond = [0] * int(tensor.shape[0]) - else: - cond_in = catenate_conds([tensor, uncond]) - cond_or_uncond = [0] * int(tensor.shape[0]) + [1] * int(uncond.shape[0]) - - if shared.opts.batch_cond_uncond: - self.inner_model.inner_model.cond_or_uncond = cond_or_uncond - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - self.inner_model.inner_model.cond_or_uncond = cond_or_uncond[a:b] - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - - if not is_edit_model: - c_crossattn = subscript_cond(tensor, a, b) - else: - c_crossattn = torch.cat([tensor[a:b]], uncond) - - self.inner_model.inner_model.cond_or_uncond = [0] * int(sigma_in[a:b].shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - - if not skip_uncond: - self.inner_model.inner_model.cond_or_uncond = [1] * int(sigma_in[-uncond.shape[0]:].shape[0]) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) - - denoised_image_indexes = [x[0][0] for x in conds_list] - if skip_uncond: - fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) - x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - - x_out = x_out.to(x_input_dtype) - - denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) - cfg_denoised_callback(denoised_params) - - devices.test_for_nans(x_out, "unet") - - if is_edit_model: - denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) - elif skip_uncond: - denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0, sigma_in, x_in, tensor) - else: - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, sigma_in, x_in, tensor) - - # Blend in the original latents (after) - if not self.mask_before_denoising and self.mask is not None: - denoised = apply_blend(denoised) - - self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) - - if opts.live_preview_content == "Prompt": - preview = self.sampler.last_latent - elif opts.live_preview_content == "Negative prompt": - preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma) - else: - preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma) + denoised = forge_sampler.forge_sample(self, denoiser_params=denoiser_params, cond_scale=cond_scale) + preview = self.sampler.last_latent = denoised sd_samplers_common.store_latent(preview) after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py index e68793fc..3da2ba4e 100644 --- a/modules_forge/forge_loader.py +++ b/modules_forge/forge_loader.py @@ -240,30 +240,8 @@ def load_model_for_a1111(timer, checkpoint_info=None, state_dict=None): sd_model.get_first_stage_encoding = lambda x: x sd_model.decode_first_stage = patched_decode_first_stage sd_model.encode_first_stage = patched_encode_first_stage - - sd_model.current_controlnet_signals = { - 'input': [], - 'middle': [], - 'output': [] - } - sd_model.current_controlnet_required_memory = 0 - - original_forward = sd_model.model.diffusion_model.forward - - def forge_unet_forward(*args, **kwargs): - current_transformer_options = kwargs.get('transformer_options', {}) - current_transformer_options.update(dict(cond_or_uncond=sd_model.cond_or_uncond, sigmas=sd_model.current_sigmas)) - current_transformer_options.update(sd_model.forge_objects.unet.model_options.get('transformer_options', {})) - - kwargs.update(dict( - control=sd_model.current_controlnet_signals, - transformer_options=current_transformer_options - )) - return original_forward(*args, **kwargs) - - sd_model.model.diffusion_model.forward = forge_unet_forward - sd_model.clip = sd_model.cond_stage_model + sd_model.current_controlnet_required_memory = 0 timer.record("forge finalize") sd_model.current_lora_hash = str([]) diff --git a/modules_forge/forge_sampler.py b/modules_forge/forge_sampler.py new file mode 100644 index 00000000..179c03fc --- /dev/null +++ b/modules_forge/forge_sampler.py @@ -0,0 +1,49 @@ +import torch +from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn +from ldm_patched.modules.samplers import sampling_function + + +def cond_from_a1111_to_patched_ldm(cond): + if isinstance(cond, torch.Tensor): + result = dict( + cross_attn=cond, + model_conds=dict( + c_crossattn=CONDCrossAttn(cond), + ) + ) + return [result, ] + + cross_attn = cond['crossattn'] + pooled_output = cond['vector'] + + result = dict( + cross_attn=cross_attn, + pooled_output=pooled_output, + model_conds=dict( + c_crossattn=CONDCrossAttn(cross_attn), + y=CONDRegular(pooled_output) + ) + ) + + return [result, ] + + +def forge_sample(self, denoiser_params, cond_scale): + model = self.inner_model.inner_model.forge_objects.unet.model + x = denoiser_params.x + timestep = denoiser_params.sigma + uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond) + cond = cond_from_a1111_to_patched_ldm(denoiser_params.text_cond) + model_options = self.inner_model.inner_model.forge_objects.unet.model_options + seed = self.p.seeds[0] + + image_cond_in = denoiser_params.image_cond + if isinstance(image_cond_in, torch.Tensor): + if image_cond_in.shape[0] == x.shape[0] \ + and image_cond_in.shape[2] == x.shape[2] \ + and image_cond_in.shape[3] == x.shape[3]: + uncond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in) + cond[0]['model_conds']['c_concat'] = CONDRegular(image_cond_in) + + denoised = sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed) + return denoised diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py index 30fcc373..b1afd9ab 100644 --- a/modules_forge/forge_util.py +++ b/modules_forge/forge_util.py @@ -5,8 +5,6 @@ import time import random import string -from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn - def generate_random_filename(extension=".txt"): timestamp = time.strftime("%Y%m%d-%H%M%S") @@ -15,31 +13,6 @@ def generate_random_filename(extension=".txt"): return filename -def cond_from_a1111_to_patched_ldm(cond): - if isinstance(cond, torch.Tensor): - result = dict( - cross_attn=cond, - model_conds=dict( - c_crossattn=CONDCrossAttn(cond), - ) - ) - return [result, ] - - cross_attn = cond['crossattn'] - pooled_output = cond['vector'] - - result = dict( - cross_attn=cross_attn, - pooled_output=pooled_output, - model_conds=dict( - c_crossattn=CONDCrossAttn(cross_attn), - y=CONDRegular(pooled_output) - ) - ) - - return [result, ] - - @torch.no_grad() @torch.inference_mode() def pytorch_to_numpy(x):