From 998327c744d83a9327289c5639afb603cc2d3464 Mon Sep 17 00:00:00 2001 From: lllyasviel Date: Sat, 10 Feb 2024 03:24:40 -0800 Subject: [PATCH] add edit model --- ldm_patched/modules/samplers.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py index 0c2e030d..03a6329d 100644 --- a/ldm_patched/modules/samplers.py +++ b/ldm_patched/modules/samplers.py @@ -272,6 +272,8 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #The main sampling function shared by all the samplers #Returns denoised def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None): + edit_strength = sum((item['strength'] if 'strength' in item else 1) for item in cond) + if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: uncond_ = None else: @@ -281,10 +283,13 @@ def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_option model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options) cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options) + if "sampler_cfg_function" in model_options: args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep, "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options} cfg_result = x - model_options["sampler_cfg_function"](args) + elif not math.isclose(edit_strength, 1.0): + cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale * edit_strength else: cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale