From 6ed95d66e55d0489ab476a1573e1efb8c8b13a17 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Fri, 2 Aug 2024 14:56:44 -0700
Subject: [PATCH] rework several component patchers

backend is 65% finished
---
 backend/modules/clip.py                      |  14 --
 backend/patcher/clip.py                      |  42 ++++
 backend/patcher/unet.py                      | 190 +++++++++++++++
 backend/patcher/vae.py                       | 223 ++++++++++++++++++
 .../scripts/forge_perturbed_attention.py     |   4 +-
 ldm_patched/modules/sd.py                    |   3 -
 modules_forge/forge_loader.py                |   4 +-
 modules_forge/unet_patcher.py                | 216 +----------------
 8 files changed, 461 insertions(+), 235 deletions(-)
 delete mode 100644 backend/modules/clip.py
 create mode 100644 backend/patcher/clip.py
 create mode 100644 backend/patcher/unet.py
 create mode 100644 backend/patcher/vae.py

diff --git a/backend/modules/clip.py b/backend/modules/clip.py
deleted file mode 100644
index 7adbded3..00000000
--- a/backend/modules/clip.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import torch
-
-
-class JointTokenizer:
-    def __init__(self, huggingface_components):
-        self.clip_l = huggingface_components.get('tokenizer', None)
-        self.clip_g = huggingface_components.get('tokenizer_2', None)
-
-
-class JointCLIP(torch.nn.Module):
-    def __init__(self, huggingface_components):
-        super().__init__()
-        self.clip_l = huggingface_components.get('text_encoder', None)
-        self.clip_g = huggingface_components.get('text_encoder_2', None)
diff --git a/backend/patcher/clip.py b/backend/patcher/clip.py
new file mode 100644
index 00000000..260b91c1
--- /dev/null
+++ b/backend/patcher/clip.py
@@ -0,0 +1,42 @@
+import torch
+
+from backend import memory_management
+from backend.patcher.base import ModelPatcher
+
+
+class JointTokenizer:
+    def __init__(self, huggingface_components):
+        self.clip_l = huggingface_components.get('tokenizer', None)
+        self.clip_g = huggingface_components.get('tokenizer_2', None)
+
+
+class JointCLIPTextEncoder(torch.nn.Module):
+    def __init__(self, huggingface_components):
+        super().__init__()
+        self.clip_l = huggingface_components.get('text_encoder', None)
+        self.clip_g = huggingface_components.get('text_encoder_2', None)
+
+
+class CLIP:
+    def __init__(self, huggingface_components=None, no_init=False):
+        if no_init:
+            return
+
+        load_device = memory_management.text_encoder_device()
+        offload_device = memory_management.text_encoder_offload_device()
+        text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
+
+        self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
+        self.tokenizer = JointTokenizer(huggingface_components)
+        self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
+        self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+
+    def clone(self):
+        n = CLIP(no_init=True)
+        n.patcher = self.patcher.clone()
+        n.cond_stage_model = self.cond_stage_model
+        n.tokenizer = self.tokenizer
+        return n
+
+    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
+        return self.patcher.add_patches(patches, strength_patch, strength_model)
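
Note (not part of the patch): a minimal usage sketch of the new CLIP wrapper, assuming `huggingface_components` is the loader's component dict (keys 'tokenizer', 'tokenizer_2', 'text_encoder', 'text_encoder_2') and `lora_patches` is a patch dict in the format ModelPatcher.add_patches already accepts; both names are placeholders for this example.

    from backend.patcher.clip import CLIP

    clip = CLIP(huggingface_components)     # text encoders wrapped in a ModelPatcher, kept on the offload device
    patched = clip.clone()                  # shares the underlying text encoders and tokenizer
    patched.add_patches(lora_patches, strength_patch=0.8, strength_model=1.0)
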
diff --git a/backend/patcher/unet.py b/backend/patcher/unet.py
new file mode 100644
index 00000000..38528441
--- /dev/null
+++ b/backend/patcher/unet.py
@@ -0,0 +1,190 @@
+import copy
+import torch
+
+from backend.patcher.base import ModelPatcher
+
+
+class UnetPatcher(ModelPatcher):
+    def __init__(self, model, *args, **kwargs):
+        super().__init__(model, *args, **kwargs)
+        self.controlnet_linked_list = None
+        self.extra_preserved_memory_during_sampling = 0
+        self.extra_model_patchers_during_sampling = []
+        self.extra_concat_condition = None
+
+    def clone(self):
+        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
+                        weight_inplace_update=self.weight_inplace_update)
+
+        n.patches = {}
+        for k in self.patches:
+            n.patches[k] = self.patches[k][:]
+
+        n.object_patches = self.object_patches.copy()
+        n.model_options = copy.deepcopy(self.model_options)
+        n.controlnet_linked_list = self.controlnet_linked_list
+        n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
+        n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
+        n.extra_concat_condition = self.extra_concat_condition
+        return n
+
+    def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
+        # Use this to ask Forge to preserve a certain amount of memory during sampling.
+        # For example, if GPU VRAM is 8 GB and memory_in_bytes is 2 GB, i.e. memory_in_bytes = 2 * 1024 * 1024 * 1024,
+        # then sampling will always use less than 6 GB of memory by dynamically offloading modules to CPU RAM.
+        # You can estimate the size of any PyTorch model with memory_management.module_size(any_pytorch_model).
+        self.extra_preserved_memory_during_sampling += memory_in_bytes
+        return
+
+    def add_extra_model_patcher_during_sampling(self, model_patcher: ModelPatcher):
+        # Use this to ask Forge to move extra model patchers to the GPU during sampling.
+        # Forge will manage the GPU memory of these patchers automatically.
+        self.extra_model_patchers_during_sampling.append(model_patcher)
+        return
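+
+    # A hypothetical example (not part of this file) of how an extension might use the two
+    # methods above from its process_before_every_sampling hook:
+    #
+    #     unet = p.sd_model.forge_objects.unet.clone()
+    #     unet.add_extra_preserved_memory_during_sampling(2 * 1024 ** 3)       # keep ~2 GB of VRAM free
+    #     unet.add_extra_model_patcher_during_sampling(extra_patcher)          # extra_patcher: any ModelPatcher
+    #     p.sd_model.forge_objects.unet = unet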
+
+    def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
+        # Use this method to bind an extra torch.nn.Module to this UNet during sampling.
+        # The module `m` will be delegated to the Forge memory management system.
+        # `m` will be loaded to the GPU every time sampling starts, and unloaded if necessary.
+        # `m` will be taken into account when Forge judges available GPU memory and decides
+        # whether to offload modules so that the user can run a larger batch size.
+        # Set cast_to_unet_dtype if you want `m` to have the same dtype as the UNet during sampling.
+
+        if cast_to_unet_dtype:
+            m.to(self.model.diffusion_model.dtype)
+
+        patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)
+
+        self.add_extra_model_patcher_during_sampling(patcher)
+        return patcher
+
+    def add_patched_controlnet(self, cnet):
+        cnet.set_previous_controlnet(self.controlnet_linked_list)
+        self.controlnet_linked_list = cnet
+        return
+
+    def list_controlnets(self):
+        results = []
+        pointer = self.controlnet_linked_list
+        while pointer is not None:
+            results.append(pointer)
+            pointer = pointer.previous_controlnet
+        return results
+
+    def append_model_option(self, k, v, ensure_uniqueness=False):
+        if k not in self.model_options:
+            self.model_options[k] = []
+
+        if ensure_uniqueness and v in self.model_options[k]:
+            return
+
+        self.model_options[k].append(v)
+        return
+
+    def append_transformer_option(self, k, v, ensure_uniqueness=False):
+        if 'transformer_options' not in self.model_options:
+            self.model_options['transformer_options'] = {}
+
+        to = self.model_options['transformer_options']
+
+        if k not in to:
+            to[k] = []
+
+        if ensure_uniqueness and v in to[k]:
+            return
+
+        to[k].append(v)
+        return
+
+    def set_transformer_option(self, k, v):
+        if 'transformer_options' not in self.model_options:
+            self.model_options['transformer_options'] = {}
+
+        self.model_options['transformer_options'][k] = v
+        return
+
+    def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
+        self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
+        return
+
+    def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
+        self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
+        return
+
+    def set_memory_peak_estimation_modifier(self, modifier):
+        self.model_options['memory_peak_estimation_modifier'] = modifier
+        return
+
+    def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False):
+        """
+
+        For some reason, this function only works in A1111's Script.process_batch(self, p, *args, **kwargs)
+
+        For example, below is a working modification:
+
+        class ExampleScript(scripts.Script):
+
+            def process_batch(self, p, *args, **kwargs):
+                unet = p.sd_model.forge_objects.unet.clone()
+
+                def modifier(x):
+                    return x ** 0.5
+
+                unet.add_alphas_cumprod_modifier(modifier)
+                p.sd_model.forge_objects.unet = unet
+
+                return
+
+        add_alphas_cumprod_modifier is the only patch option that should be used in process_batch();
+        all other patch options should be called in process_before_every_sampling().
+
+        """
+
+        self.append_model_option('alphas_cumprod_modifiers', modifier, ensure_uniqueness)
+        return
+
+    def add_block_modifier(self, modifier, ensure_uniqueness=False):
+        self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
+        return
+
+    def add_block_inner_modifier(self, modifier, ensure_uniqueness=False):
+        self.append_transformer_option('block_inner_modifiers', modifier, ensure_uniqueness)
+        return
+
+    def add_controlnet_conditioning_modifier(self, modifier, ensure_uniqueness=False):
+        self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
+        return
+
+    def set_groupnorm_wrapper(self, wrapper):
+        self.set_transformer_option('groupnorm_wrapper', wrapper)
+        return
+
+    def set_controlnet_model_function_wrapper(self, wrapper):
+        self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
+        return
+
+    def set_model_replace_all(self, patch, target="attn1"):
+        for block_name in ['input', 'middle', 'output']:
+            for number in range(16):
+                for transformer_index in range(16):
+                    self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
+        return
+
+    def load_frozen_patcher(self, state_dict, strength):
+        patch_dict = {}
+        for k, w in state_dict.items():
+            model_key, patch_type, weight_index = k.split('::')
+            if model_key not in patch_dict:
+                patch_dict[model_key] = {}
+            if patch_type not in patch_dict[model_key]:
+                patch_dict[model_key][patch_type] = [None] * 16
+            patch_dict[model_key][patch_type][int(weight_index)] = w
+
+        patch_flat = {}
+        for model_key, v in patch_dict.items():
+            for patch_type, weight_list in v.items():
+                patch_flat[model_key] = (patch_type, weight_list)
+
+        self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
+        return
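
Note (not part of the patch): a minimal sketch of the key layout that load_frozen_patcher expects. The keys are 'model_key::patch_type::weight_index'; all weights sharing a key and type are gathered into a 16-slot list and applied through the inherited add_patches(). The model key, the 'lora' patch type and the zero tensors below are placeholders, and `unet_patcher` is assumed to be an existing UnetPatcher instance.

    import torch

    frozen_sd = {
        'diffusion_model.some_block.weight::lora::0': torch.zeros(4, 4),   # placeholder tensor
        'diffusion_model.some_block.weight::lora::1': torch.zeros(4, 4),   # placeholder tensor
    }
    unet_patcher.load_frozen_patcher(frozen_sd, strength=0.75)
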
diff --git a/backend/patcher/vae.py b/backend/patcher/vae.py
new file mode 100644
index 00000000..f750b3c6
--- /dev/null
+++ b/backend/patcher/vae.py
@@ -0,0 +1,223 @@
+import torch
+import math
+import itertools
+
+from tqdm import tqdm
+from backend import memory_management
+from backend.patcher.base import ModelPatcher
+
+
+@torch.inference_mode()
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
+    dims = len(tile)
+    output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
+
+    for b in range(samples.shape[0]):
+        s = samples[b:b + 1]
+        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
+        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
+
+        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
+            s_in = s
+            upscaled = []
+
+            for d in range(dims):
+                pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
+                l = min(tile[d], s.shape[d + 2] - pos)
+                s_in = s_in.narrow(d + 2, pos, l)
+                upscaled.append(round(pos * upscale_amount))
+            ps = function(s_in).to(output_device)
+            mask = torch.ones_like(ps)
+            feather = round(overlap * upscale_amount)
+            for t in range(feather):
+                for d in range(2, dims + 2):
+                    m = mask.narrow(d, t, 1)
+                    m *= ((1.0 / feather) * (t + 1))
+                    m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
+                    m *= ((1.0 / feather) * (t + 1))
+
+            o = out
+            o_d = out_div
+            for d in range(dims):
+                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
+                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
+
+            o += ps * mask
+            o_d += mask
+
+            if pbar is not None:
+                pbar.update(1)
+
+        output[b:b + 1] = out / out_div
+    return output
+
+
+def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
+    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
+
+
+def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
+    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
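+
+
+# An illustrative worked example of the tiling arithmetic above (the numbers are only an example):
+# decoding a latent with tile=(64, 64), overlap=8 and upscale_amount=8 produces 512x512 pixel tiles
+# whose feathered border is overlap * upscale_amount = 64 pixels wide, blended with the linear ramp
+# built into the mask above; get_tiled_scale_steps(w, h, 64, 64, 8) counts the tiles per image.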
+
+
+class ProgressBar:
+    def __init__(self, total, title=None):
+        self.total = total
+        self.current = 0
+        self.tqdm = tqdm(total=total, desc=title)
+
+    def update_absolute(self, value, total=None, preview=None):
+        if total is not None:
+            self.total = total
+        if value > self.total:
+            value = self.total
+        inc = value - self.current
+        self.tqdm.update(inc)
+        self.current = value
+        if self.current >= self.total:
+            self.tqdm.close()
+
+    def update(self, value):
+        self.update_absolute(self.current + value)
+
+
+class VAE:
+    def __init__(self, model=None, device=None, dtype=None, no_init=False):
+        if no_init:
+            return
+
+        self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * memory_management.dtype_size(dtype)
+        self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * memory_management.dtype_size(dtype)
+        self.downscale_ratio = int(2 ** (len(model.config.down_block_types) - 1))
+        self.latent_channels = int(model.config.latent_channels)
+
+        self.first_stage_model = model.eval()
+
+        if device is None:
+            device = memory_management.vae_device()
+
+        self.device = device
+        offload_device = memory_management.vae_offload_device()
+
+        if dtype is None:
+            dtype = memory_management.vae_dtype()
+
+        self.vae_dtype = dtype
+        self.first_stage_model.to(self.vae_dtype)
+        self.output_device = memory_management.intermediate_device()
+
+        self.patcher = ModelPatcher(
+            self.first_stage_model,
+            load_device=self.device,
+            offload_device=offload_device
+        )
+
+    def clone(self):
+        n = VAE(no_init=True)
+        n.patcher = self.patcher.clone()
+        n.memory_used_encode = self.memory_used_encode
+        n.memory_used_decode = self.memory_used_decode
+        n.downscale_ratio = self.downscale_ratio
+        n.latent_channels = self.latent_channels
+        n.first_stage_model = self.first_stage_model
+        n.device = self.device
+        n.vae_dtype = self.vae_dtype
+        n.output_device = self.output_device
+        return n
+
+    def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
+        steps = samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
+        steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = ProgressBar(steps, title='VAE tiled decode')
+
+        decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
+        output = torch.clamp(((tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar) +
+                               tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar) +
+                               tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar))
+                              / 3.0) / 2.0, min=0.0, max=1.0)
+        return output
+
+    def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
+        steps = pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
+        steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
+        steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
+        pbar = ProgressBar(steps, title='VAE tiled encode')
+        encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
+        samples = tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples += tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples += tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
+        samples /= 3.0
+        return samples
+
+    def decode_inner(self, samples_in):
+        if memory_management.VAE_ALWAYS_TILED:
+            return self.decode_tiled(samples_in).to(self.output_device)
+
+        try:
+            memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
+            memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
+            free_memory = memory_management.get_free_memory(self.device)
+            batch_number = int(free_memory / memory_used)
+            batch_number = max(1, batch_number)
+
+            pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
+            for x in range(0, samples_in.shape[0], batch_number):
+                samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
+                pixel_samples[x:x + batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
+        except memory_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
+            pixel_samples = self.decode_tiled_(samples_in)
+
+        pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
+        return pixel_samples
+
+    def decode(self, samples_in):
+        wrapper = self.patcher.model_options.get('model_vae_decode_wrapper', None)
+        if wrapper is None:
+            return self.decode_inner(samples_in)
+        else:
+            return wrapper(self.decode_inner, samples_in)
+
+    def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
+        memory_management.load_model_gpu(self.patcher)
+        output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
+        return output.movedim(1, -1)
+
+    def encode_inner(self, pixel_samples):
+        if memory_management.VAE_ALWAYS_TILED:
+            return self.encode_tiled(pixel_samples)
+
+        regulation = self.patcher.model_options.get("model_vae_regulation", None)
+
+        pixel_samples = pixel_samples.movedim(-1, 1)
+        try:
+            memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
+            memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
+            free_memory = memory_management.get_free_memory(self.device)
+            batch_number = int(free_memory / memory_used)
+            batch_number = max(1, batch_number)
+            samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
+            for x in range(0, pixel_samples.shape[0], batch_number):
+                pixels_in = (2. * pixel_samples[x:x + batch_number] - 1.).to(self.vae_dtype).to(self.device)
+                samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in, regulation).to(self.output_device).float()
+
+        except memory_management.OOM_EXCEPTION as e:
+            print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
+            samples = self.encode_tiled_(pixel_samples)
+
+        return samples
+
+    def encode(self, pixel_samples):
+        wrapper = self.patcher.model_options.get('model_vae_encode_wrapper', None)
+        if wrapper is None:
+            return self.encode_inner(pixel_samples)
+        else:
+            return wrapper(self.encode_inner, pixel_samples)
+
+    def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
+        memory_management.load_model_gpu(self.patcher)
+        pixel_samples = pixel_samples.movedim(-1, 1)
+        samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
+        return samples
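
Note (not part of the patch): a minimal sketch of how the new VAE wrapper is meant to be called. `first_stage_model` stands for a diffusers-style autoencoder exposing config.down_block_types and config.latent_channels, and `latent` for a [B, C, H, W] latent tensor; both are assumptions of this example.

    from backend.patcher.vae import VAE

    vae = VAE(model=first_stage_model)
    image = vae.decode(latent)          # channels-last [B, H, W, C] pixels clamped to 0..1
    latent_again = vae.encode(image)    # channels-last pixels back to latents
    # vae.decode_tiled(latent) and vae.encode_tiled(image) force the tiled paths explicitly
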
diff --git a/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py b/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py
index 47b020f2..c506240a 100644
--- a/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py
+++ b/extensions-builtin/sd_forge_perturbed_attention/scripts/forge_perturbed_attention.py
@@ -2,7 +2,7 @@ import gradio as gr
 import ldm_patched.modules.samplers
 
 from modules import scripts
-from modules_forge.unet_patcher import copy_and_update_model_options
+from backend.patcher.base import set_model_options_patch_replace
 
 
 class PerturbedAttentionGuidanceForForge(scripts.Script):
@@ -36,7 +36,7 @@ class PerturbedAttentionGuidanceForForge(scripts.Script):
             model, cond_denoised, cond, denoised, sigma, x = \
                 args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
 
-            new_options = copy_and_update_model_options(args["model_options"], attn_proc, "attn1", "middle", 0)
+            new_options = set_model_options_patch_replace(args["model_options"], attn_proc, "attn1", "middle", 0)
 
             if scale == 0:
                 return denoised
diff --git a/ldm_patched/modules/sd.py b/ldm_patched/modules/sd.py
index 9f97ee16..38377e83 100644
--- a/ldm_patched/modules/sd.py
+++ b/ldm_patched/modules/sd.py
@@ -97,9 +97,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
     return model, clip
 
 
-from backend.modules.clip import JointCLIP, JointTokenizer
-
-
 class CLIP:
     def __init__(self, huggingface_components=None, no_init=False):
         if no_init:
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index e60f2c78..cbe71e9f 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -4,7 +4,9 @@ import contextlib
 
 from ldm_patched.modules import model_management
 from ldm_patched.modules import model_detection
-from ldm_patched.modules.sd import VAE, CLIP, load_model_weights
+from ldm_patched.modules.sd import VAE, load_model_weights
+from backend.patcher.clip import CLIP
+from backend.patcher.vae import VAE
 import ldm_patched.modules.model_patcher
 import ldm_patched.modules.utils
 import ldm_patched.modules.clip_vision
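
Note (not part of the patch): extensions that imported the removed helper from modules_forge.unet_patcher can switch to the backend helper the same way the perturbed-attention script above does. A minimal sketch, assuming set_model_options_patch_replace keeps the argument order of the removed copy_and_update_model_options (model_options, patch, name, block_name, number, transformer_index=None); `my_attn_patch` is a placeholder callable.

    from backend.patcher.base import set_model_options_patch_replace

    # presumably returns a copied model_options with the patch registered under
    # transformer_options['patches_replace'][name][(block_name, number)], mirroring the old helper
    new_options = set_model_options_patch_replace(model_options, my_attn_patch, "attn1", "middle", 0)
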
diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py
index dd0f0925..b7e6ee59 100644
--- a/modules_forge/unet_patcher.py
+++ b/modules_forge/unet_patcher.py
@@ -1,215 +1 @@
-import copy
-import torch
-
-from ldm_patched.modules.model_patcher import ModelPatcher
-from ldm_patched.modules.sample import convert_cond
-from ldm_patched.modules.samplers import encode_model_conds
-from ldm_patched.modules import model_management
-
-
-class UnetPatcher(ModelPatcher):
-    def __init__(self, model, *args, **kwargs):
-        super().__init__(model, *args, **kwargs)
-        self.controlnet_linked_list = None
-        self.extra_preserved_memory_during_sampling = 0
-        self.extra_model_patchers_during_sampling = []
-        self.extra_concat_condition = None
-
-    def clone(self):
-        n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
-                        weight_inplace_update=self.weight_inplace_update)
-
-        n.patches = {}
-        for k in self.patches:
-            n.patches[k] = self.patches[k][:]
-
-        n.object_patches = self.object_patches.copy()
-        n.model_options = copy.deepcopy(self.model_options)
-        n.controlnet_linked_list = self.controlnet_linked_list
-        n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
-        n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
-        n.extra_concat_condition = self.extra_concat_condition
-        return n
-
-    def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
-        # Use this to ask Forge to preserve a certain amount of memory during sampling.
-        # If GPU VRAM is 8 GB, and memory_in_bytes is 2GB, i.e., memory_in_bytes = 2 * 1024 * 1024 * 1024
-        # Then the sampling will always use less than 6GB memory by dynamically offload modules to CPU RAM.
-        # You can estimate this using model_management.module_size(any_pytorch_model) to get size of any pytorch models.
-        self.extra_preserved_memory_during_sampling += memory_in_bytes
-        return
-
-    def add_extra_model_patcher_during_sampling(self, model_patcher: ModelPatcher):
-        # Use this to ask Forge to move extra model patchers to GPU during sampling.
-        # This method will manage GPU memory perfectly.
-        self.extra_model_patchers_during_sampling.append(model_patcher)
-        return
-
-    def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
-        # Use this method to bind an extra torch.nn.Module to this UNet during sampling.
-        # This model `m` will be delegated to Forge memory management system.
-        # `m` will be loaded to GPU everytime when sampling starts.
-        # `m` will be unloaded if necessary.
-        # `m` will influence Forge's judgement about use GPU memory or
-        # capacity and decide whether to use module offload to make user's batch size larger.
-        # Use cast_to_unet_dtype if you want `m` to have same dtype with unet during sampling.
-
-        if cast_to_unet_dtype:
-            m.to(self.model.diffusion_model.dtype)
-
-        patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)
-
-        self.add_extra_model_patcher_during_sampling(patcher)
-        return patcher
-
-    def add_patched_controlnet(self, cnet):
-        cnet.set_previous_controlnet(self.controlnet_linked_list)
-        self.controlnet_linked_list = cnet
-        return
-
-    def list_controlnets(self):
-        results = []
-        pointer = self.controlnet_linked_list
-        while pointer is not None:
-            results.append(pointer)
-            pointer = pointer.previous_controlnet
-        return results
-
-    def append_model_option(self, k, v, ensure_uniqueness=False):
-        if k not in self.model_options:
-            self.model_options[k] = []
-
-        if ensure_uniqueness and v in self.model_options[k]:
-            return
-
-        self.model_options[k].append(v)
-        return
-
-    def append_transformer_option(self, k, v, ensure_uniqueness=False):
-        if 'transformer_options' not in self.model_options:
-            self.model_options['transformer_options'] = {}
-
-        to = self.model_options['transformer_options']
-
-        if k not in to:
-            to[k] = []
-
-        if ensure_uniqueness and v in to[k]:
-            return
-
-        to[k].append(v)
-        return
-
-    def set_transformer_option(self, k, v):
-        if 'transformer_options' not in self.model_options:
-            self.model_options['transformer_options'] = {}
-
-        self.model_options['transformer_options'][k] = v
-        return
-
-    def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
-        self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
-        return
-
-    def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
-        self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
-        return
-
-    def set_memory_peak_estimation_modifier(self, modifier):
-        self.model_options['memory_peak_estimation_modifier'] = modifier
-        return
-
-    def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False):
-        """
-
-        For some reasons, this function only works in A1111's Script.process_batch(self, p, *args, **kwargs)
-
-        For example, below is a worked modification:
-
-        class ExampleScript(scripts.Script):
-
-            def process_batch(self, p, *args, **kwargs):
-                unet = p.sd_model.forge_objects.unet.clone()
-
-                def modifier(x):
-                    return x ** 0.5
-
-                unet.add_alphas_cumprod_modifier(modifier)
-                p.sd_model.forge_objects.unet = unet
-
-                return
-
-        This add_alphas_cumprod_modifier is the only patch option that should be used in process_batch()
-        All other patch options should be called in process_before_every_sampling()
-
-        """
-
-        self.append_model_option('alphas_cumprod_modifiers', modifier, ensure_uniqueness)
-        return
-
-    def add_block_modifier(self, modifier, ensure_uniqueness=False):
-        self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
-        return
-
-    def add_block_inner_modifier(self, modifier, ensure_uniqueness=False):
-        self.append_transformer_option('block_inner_modifiers', modifier, ensure_uniqueness)
-        return
-
-    def add_controlnet_conditioning_modifier(self, modifier, ensure_uniqueness=False):
-        self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
-        return
-
-    def set_groupnorm_wrapper(self, wrapper):
-        self.set_transformer_option('groupnorm_wrapper', wrapper)
-        return
-
-    def set_controlnet_model_function_wrapper(self, wrapper):
-        self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
-        return
-
-    def set_model_replace_all(self, patch, target="attn1"):
-        for block_name in ['input', 'middle', 'output']:
-            for number in range(16):
-                for transformer_index in range(16):
-                    self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
-        return
-
-    def encode_conds_after_clip(self, conds, noise, prompt_type="positive"):
-        return encode_model_conds(
-            model_function=self.model.extra_conds,
-            conds=convert_cond(conds),
-            noise=noise,
-            device=noise.device,
-            prompt_type=prompt_type
-        )
-
-    def load_frozen_patcher(self, state_dict, strength):
-        patch_dict = {}
-        for k, w in state_dict.items():
-            model_key, patch_type, weight_index = k.split('::')
-            if model_key not in patch_dict:
-                patch_dict[model_key] = {}
-            if patch_type not in patch_dict[model_key]:
-                patch_dict[model_key][patch_type] = [None] * 16
-            patch_dict[model_key][patch_type][int(weight_index)] = w
-
-        patch_flat = {}
-        for model_key, v in patch_dict.items():
-            for patch_type, weight_list in v.items():
-                patch_flat[model_key] = (patch_type, weight_list)
-
-        self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
-        return
-
-
-def copy_and_update_model_options(model_options, patch, name, block_name, number, transformer_index=None):
-    model_options = model_options.copy()
-    transformer_options = model_options.get("transformer_options", {}).copy()
-    patches_replace = transformer_options.get("patches_replace", {}).copy()
-    name_patches = patches_replace.get(name, {}).copy()
-    block = (block_name, number, transformer_index) if transformer_index is not None else (block_name, number)
-    name_patches[block] = patch
-    patches_replace[name] = name_patches
-    transformer_options["patches_replace"] = patches_replace
-    model_options["transformer_options"] = transformer_options
-    return model_options
+from backend.patcher.unet import *
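
Note (not part of the patch): the old module path is kept as a thin re-export, so existing imports still resolve to the reworked class. A quick illustrative check:

    from modules_forge.unet_patcher import UnetPatcher as OldPath
    from backend.patcher.unet import UnetPatcher as NewPath

    assert OldPath is NewPath   # modules_forge.unet_patcher is now just `from backend.patcher.unet import *`

The removed copy_and_update_model_options helper is not re-exported; callers should use backend.patcher.base.set_model_options_patch_replace, as in the perturbed-attention script change above.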