diff --git a/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py b/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py new file mode 100644 index 00000000..9c55240f --- /dev/null +++ b/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres.py @@ -0,0 +1,49 @@ +# https://github.com/mcmonkeyprojects/sd-dynamic-thresholding + + +from lib_dynamic_thresholding.dynthres_core import DynThresh + + +class DynamicThresholdingNode: + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "mimic_scale": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step": 0.5}), + "threshold_percentile": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "mimic_mode": (DynThresh.Modes, ), + "mimic_scale_min": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.5}), + "cfg_mode": (DynThresh.Modes, ), + "cfg_scale_min": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step": 0.5}), + "sched_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}), + "separate_feature_channels": (["enable", "disable"], ), + "scaling_startpoint": (DynThresh.Startpoints, ), + "variability_measure": (DynThresh.Variabilities, ), + "interpolate_phi": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + } + } + + RETURN_TYPES = ("MODEL",) + FUNCTION = "patch" + CATEGORY = "advanced/mcmonkey" + + def patch(self, model, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi): + + dynamic_thresh = DynThresh(mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, 0, 999, separate_feature_channels == "enable", scaling_startpoint, variability_measure, interpolate_phi) + + def sampler_dyn_thresh(args): + input = args["input"] + cond = input - args["cond"] + uncond = input - args["uncond"] + cond_scale = args["cond_scale"] + time_step = model.model.model_sampling.timestep(args["sigma"]) + time_step = time_step[0].item() + dynamic_thresh.step = 999 - time_step + + return input - dynamic_thresh.dynthresh(cond, uncond, cond_scale, None) + + m = model.clone() + m.set_model_sampler_cfg_function(sampler_dyn_thresh) + return (m, ) diff --git a/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres_core.py b/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres_core.py new file mode 100644 index 00000000..f697441e --- /dev/null +++ b/extensions-builtin/sd_forge_dynamic_thresholding/lib_dynamic_thresholding/dynthres_core.py @@ -0,0 +1,170 @@ +# https://github.com/mcmonkeyprojects/sd-dynamic-thresholding + + +import torch, math + +######################### DynThresh Core ######################### + +class DynThresh: + + Modes = ["Constant", "Linear Down", "Cosine Down", "Half Cosine Down", "Linear Up", "Cosine Up", "Half Cosine Up", "Power Up", "Power Down", "Linear Repeating", "Cosine Repeating", "Sawtooth"] + Startpoints = ["MEAN", "ZERO"] + Variabilities = ["AD", "STD"] + + def __init__(self, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, sched_val, experiment_mode, max_steps, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi): + self.mimic_scale = mimic_scale + self.threshold_percentile = threshold_percentile + self.mimic_mode = mimic_mode + self.cfg_mode = cfg_mode + self.max_steps = max_steps + self.cfg_scale_min = cfg_scale_min + self.mimic_scale_min = mimic_scale_min + self.experiment_mode = experiment_mode + self.sched_val = sched_val + self.sep_feat_channels = separate_feature_channels + self.scaling_startpoint = scaling_startpoint + self.variability_measure = variability_measure + self.interpolate_phi = interpolate_phi + + def interpret_scale(self, scale, mode, min): + scale -= min + max = self.max_steps - 1 + frac = self.step / max + if mode == "Constant": + pass + elif mode == "Linear Down": + scale *= 1.0 - frac + elif mode == "Half Cosine Down": + scale *= math.cos(frac) + elif mode == "Cosine Down": + scale *= math.cos(frac * 1.5707) + elif mode == "Linear Up": + scale *= frac + elif mode == "Half Cosine Up": + scale *= 1.0 - math.cos(frac) + elif mode == "Cosine Up": + scale *= 1.0 - math.cos(frac * 1.5707) + elif mode == "Power Up": + scale *= math.pow(frac, self.sched_val) + elif mode == "Power Down": + scale *= 1.0 - math.pow(frac, self.sched_val) + elif mode == "Linear Repeating": + portion = (frac * self.sched_val) % 1.0 + scale *= (0.5 - portion) * 2 if portion < 0.5 else (portion - 0.5) * 2 + elif mode == "Cosine Repeating": + scale *= math.cos(frac * 6.28318 * self.sched_val) * 0.5 + 0.5 + elif mode == "Sawtooth": + scale *= (frac * self.sched_val) % 1.0 + scale += min + return scale + + def dynthresh(self, cond, uncond, cfg_scale, weights): + mimic_scale = self.interpret_scale(self.mimic_scale, self.mimic_mode, self.mimic_scale_min) + cfg_scale = self.interpret_scale(cfg_scale, self.cfg_mode, self.cfg_scale_min) + # uncond shape is (batch, 4, height, width) + conds_per_batch = cond.shape[0] / uncond.shape[0] + assert conds_per_batch == int(conds_per_batch), "Expected # of conds per batch to be constant across batches" + cond_stacked = cond.reshape((-1, int(conds_per_batch)) + uncond.shape[1:]) + + ### Normal first part of the CFG Scale logic, basically + diff = cond_stacked - uncond.unsqueeze(1) + if weights is not None: + diff = diff * weights + relative = diff.sum(1) + + ### Get the normal result for both mimic and normal scale + mim_target = uncond + relative * mimic_scale + cfg_target = uncond + relative * cfg_scale + ### If we weren't doing mimic scale, we'd just return cfg_target here + + ### Now recenter the values relative to their average rather than absolute, to allow scaling from average + mim_flattened = mim_target.flatten(2) + cfg_flattened = cfg_target.flatten(2) + mim_means = mim_flattened.mean(dim=2).unsqueeze(2) + cfg_means = cfg_flattened.mean(dim=2).unsqueeze(2) + mim_centered = mim_flattened - mim_means + cfg_centered = cfg_flattened - cfg_means + + if self.sep_feat_channels: + if self.variability_measure == 'STD': + mim_scaleref = mim_centered.std(dim=2).unsqueeze(2) + cfg_scaleref = cfg_centered.std(dim=2).unsqueeze(2) + else: # 'AD' + mim_scaleref = mim_centered.abs().max(dim=2).values.unsqueeze(2) + cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile, dim=2).unsqueeze(2) + + else: + if self.variability_measure == 'STD': + mim_scaleref = mim_centered.std() + cfg_scaleref = cfg_centered.std() + else: # 'AD' + mim_scaleref = mim_centered.abs().max() + cfg_scaleref = torch.quantile(cfg_centered.abs(), self.threshold_percentile) + + if self.scaling_startpoint == 'ZERO': + scaling_factor = mim_scaleref / cfg_scaleref + result = cfg_flattened * scaling_factor + + else: # 'MEAN' + if self.variability_measure == 'STD': + cfg_renormalized = (cfg_centered / cfg_scaleref) * mim_scaleref + else: # 'AD' + ### Get the maximum value of all datapoints (with an optional threshold percentile on the uncond) + max_scaleref = torch.maximum(mim_scaleref, cfg_scaleref) + ### Clamp to the max + cfg_clamped = cfg_centered.clamp(-max_scaleref, max_scaleref) + ### Now shrink from the max to normalize and grow to the mimic scale (instead of the CFG scale) + cfg_renormalized = (cfg_clamped / max_scaleref) * mim_scaleref + + ### Now add it back onto the averages to get into real scale again and return + result = cfg_renormalized + cfg_means + + actual_res = result.unflatten(2, mim_target.shape[2:]) + + if self.interpolate_phi != 1.0: + actual_res = actual_res * self.interpolate_phi + cfg_target * (1.0 - self.interpolate_phi) + + if self.experiment_mode == 1: + num = actual_res.cpu().numpy() + for y in range(0, 64): + for x in range (0, 64): + if num[0][0][y][x] > 1.0: + num[0][1][y][x] *= 0.5 + if num[0][1][y][x] > 1.0: + num[0][1][y][x] *= 0.5 + if num[0][2][y][x] > 1.5: + num[0][2][y][x] *= 0.5 + actual_res = torch.from_numpy(num).to(device=uncond.device) + elif self.experiment_mode == 2: + num = actual_res.cpu().numpy() + for y in range(0, 64): + for x in range (0, 64): + over_scale = False + for z in range(0, 4): + if abs(num[0][z][y][x]) > 1.5: + over_scale = True + if over_scale: + for z in range(0, 4): + num[0][z][y][x] *= 0.7 + actual_res = torch.from_numpy(num).to(device=uncond.device) + elif self.experiment_mode == 3: + coefs = torch.tensor([ + # R G B W + [0.298, 0.207, 0.208, 0.0], # L1 + [0.187, 0.286, 0.173, 0.0], # L2 + [-0.158, 0.189, 0.264, 0.0], # L3 + [-0.184, -0.271, -0.473, 1.0], # L4 + ], device=uncond.device) + res_rgb = torch.einsum("laxy,ab -> lbxy", actual_res, coefs) + max_r, max_g, max_b, max_w = res_rgb[0][0].max(), res_rgb[0][1].max(), res_rgb[0][2].max(), res_rgb[0][3].max() + max_rgb = max(max_r, max_g, max_b) + print(f"test max = r={max_r}, g={max_g}, b={max_b}, w={max_w}, rgb={max_rgb}") + if self.step / (self.max_steps - 1) > 0.2: + if max_rgb < 2.0 and max_w < 3.0: + res_rgb /= max_rgb / 2.4 + else: + if max_rgb > 2.4 and max_w > 3.0: + res_rgb /= max_rgb / 2.4 + actual_res = torch.einsum("laxy,ab -> lbxy", res_rgb, coefs.inverse()) + + return actual_res diff --git a/extensions-builtin/sd_forge_dynamic_thresholding/scripts/forge_dynamic_thresholding.py b/extensions-builtin/sd_forge_dynamic_thresholding/scripts/forge_dynamic_thresholding.py new file mode 100644 index 00000000..686ffe47 --- /dev/null +++ b/extensions-builtin/sd_forge_dynamic_thresholding/scripts/forge_dynamic_thresholding.py @@ -0,0 +1,79 @@ +import gradio as gr + +from modules import scripts +from lib_dynamic_thresholding.dynthres import DynamicThresholdingNode + +opDynamicThresholdingNode = DynamicThresholdingNode().patch + + +class DynamicThresholdingForForge(scripts.Script): + def title(self): + return "DynamicThresholding (CFG-Fix) Integrated" + + def show(self, is_img2img): + # make this extension visible in both txt2img and img2img tab. + return scripts.AlwaysVisible + + def ui(self, *args, **kwargs): + with gr.Accordion(open=False, label=self.title()): + enabled = gr.Checkbox(label='Enabled', value=False) + mimic_scale = gr.Slider(label='Mimic Scale', minimum=0.0, maximum=100.0, step=0.5, value=7.0) + threshold_percentile = gr.Slider(label='Threshold Percentile', minimum=0.0, maximum=1.0, step=0.01, + value=1.0) + mimic_mode = gr.Radio(label='Mimic Mode', + choices=['Constant', 'Linear Down', 'Cosine Down', 'Half Cosine Down', 'Linear Up', + 'Cosine Up', 'Half Cosine Up', 'Power Up', 'Power Down', 'Linear Repeating', + 'Cosine Repeating', 'Sawtooth'], value='Constant') + mimic_scale_min = gr.Slider(label='Mimic Scale Min', minimum=0.0, maximum=100.0, step=0.5, value=0.0) + cfg_mode = gr.Radio(label='Cfg Mode', + choices=['Constant', 'Linear Down', 'Cosine Down', 'Half Cosine Down', 'Linear Up', + 'Cosine Up', 'Half Cosine Up', 'Power Up', 'Power Down', 'Linear Repeating', + 'Cosine Repeating', 'Sawtooth'], value='Constant') + cfg_scale_min = gr.Slider(label='Cfg Scale Min', minimum=0.0, maximum=100.0, step=0.5, value=0.0) + sched_val = gr.Slider(label='Sched Val', minimum=0.0, maximum=100.0, step=0.01, value=1.0) + separate_feature_channels = gr.Radio(label='Separate Feature Channels', choices=['enable', 'disable'], + value='enable') + scaling_startpoint = gr.Radio(label='Scaling Startpoint', choices=['MEAN', 'ZERO'], value='MEAN') + variability_measure = gr.Radio(label='Variability Measure', choices=['AD', 'STD'], value='AD') + interpolate_phi = gr.Slider(label='Interpolate Phi', minimum=0.0, maximum=1.0, step=0.01, value=1.0) + + return enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, \ + sched_val, separate_feature_channels, scaling_startpoint, variability_measure, interpolate_phi + + def process_before_every_sampling(self, p, *script_args, **kwargs): + # This will be called before every sampling. + # If you use highres fix, this will be called twice. + + enabled, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, cfg_mode, cfg_scale_min, \ + sched_val, separate_feature_channels, scaling_startpoint, variability_measure, \ + interpolate_phi = script_args + + if not enabled: + return + + unet = p.sd_model.forge_objects.unet + + unet = opDynamicThresholdingNode(unet, mimic_scale, threshold_percentile, mimic_mode, mimic_scale_min, + cfg_mode, cfg_scale_min, sched_val, separate_feature_channels, + scaling_startpoint, variability_measure, interpolate_phi)[0] + + p.sd_model.forge_objects.unet = unet + + # Below codes will add some logs to the texts below the image outputs on UI. + # The extra_generation_params does not influence results. + p.extra_generation_params.update(dict( + dynthres_enabled=enabled, + dynthres_mimic_scale=mimic_scale, + dynthres_threshold_percentile=threshold_percentile, + dynthres_mimic_mode=mimic_mode, + dynthres_mimic_scale_min=mimic_scale_min, + dynthres_cfg_mode=cfg_mode, + dynthres_cfg_scale_min=cfg_scale_min, + dynthres_sched_val=sched_val, + dynthres_separate_feature_channels=separate_feature_channels, + dynthres_scaling_startpoint=scaling_startpoint, + dynthres_variability_measure=variability_measure, + dynthres_interpolate_phi=interpolate_phi, + )) + + return