mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-21 15:23:58 +00:00
rework several component patcher
backend is 65% finished
This commit is contained in:
@@ -1,14 +0,0 @@
|
||||
import torch
|
||||
|
||||
|
||||
class JointTokenizer:
|
||||
def __init__(self, huggingface_components):
|
||||
self.clip_l = huggingface_components.get('tokenizer', None)
|
||||
self.clip_g = huggingface_components.get('tokenizer_2', None)
|
||||
|
||||
|
||||
class JointCLIP(torch.nn.Module):
|
||||
def __init__(self, huggingface_components):
|
||||
super().__init__()
|
||||
self.clip_l = huggingface_components.get('text_encoder', None)
|
||||
self.clip_g = huggingface_components.get('text_encoder_2', None)
|
||||
42
backend/patcher/clip.py
Normal file
42
backend/patcher/clip.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
class JointTokenizer:
|
||||
def __init__(self, huggingface_components):
|
||||
self.clip_l = huggingface_components.get('tokenizer', None)
|
||||
self.clip_g = huggingface_components.get('tokenizer_2', None)
|
||||
|
||||
|
||||
class JointCLIPTextEncoder(torch.nn.Module):
|
||||
def __init__(self, huggingface_components):
|
||||
super().__init__()
|
||||
self.clip_l = huggingface_components.get('text_encoder', None)
|
||||
self.clip_g = huggingface_components.get('text_encoder_2', None)
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, huggingface_components=None, no_init=False):
|
||||
if no_init:
|
||||
return
|
||||
|
||||
load_device = memory_management.text_encoder_device()
|
||||
offload_device = memory_management.text_encoder_offload_device()
|
||||
text_encoder_dtype = memory_management.text_encoder_dtype(load_device)
|
||||
|
||||
self.cond_stage_model = JointCLIPTextEncoder(huggingface_components)
|
||||
self.tokenizer = JointTokenizer(huggingface_components)
|
||||
self.cond_stage_model.to(dtype=text_encoder_dtype, device=offload_device)
|
||||
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
|
||||
def clone(self):
|
||||
n = CLIP(no_init=True)
|
||||
n.patcher = self.patcher.clone()
|
||||
n.cond_stage_model = self.cond_stage_model
|
||||
n.tokenizer = self.tokenizer
|
||||
return n
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
190
backend/patcher/unet.py
Normal file
190
backend/patcher/unet.py
Normal file
@@ -0,0 +1,190 @@
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(model, *args, **kwargs)
|
||||
self.controlnet_linked_list = None
|
||||
self.extra_preserved_memory_during_sampling = 0
|
||||
self.extra_model_patchers_during_sampling = []
|
||||
self.extra_concat_condition = None
|
||||
|
||||
def clone(self):
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
||||
weight_inplace_update=self.weight_inplace_update)
|
||||
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.controlnet_linked_list = self.controlnet_linked_list
|
||||
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
|
||||
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
|
||||
n.extra_concat_condition = self.extra_concat_condition
|
||||
return n
|
||||
|
||||
def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
|
||||
# Use this to ask Forge to preserve a certain amount of memory during sampling.
|
||||
# If GPU VRAM is 8 GB, and memory_in_bytes is 2GB, i.e., memory_in_bytes = 2 * 1024 * 1024 * 1024
|
||||
# Then the sampling will always use less than 6GB memory by dynamically offload modules to CPU RAM.
|
||||
# You can estimate this using memory_management.module_size(any_pytorch_model) to get size of any pytorch models.
|
||||
self.extra_preserved_memory_during_sampling += memory_in_bytes
|
||||
return
|
||||
|
||||
def add_extra_model_patcher_during_sampling(self, model_patcher: ModelPatcher):
|
||||
# Use this to ask Forge to move extra model patchers to GPU during sampling.
|
||||
# This method will manage GPU memory perfectly.
|
||||
self.extra_model_patchers_during_sampling.append(model_patcher)
|
||||
return
|
||||
|
||||
def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
|
||||
# Use this method to bind an extra torch.nn.Module to this UNet during sampling.
|
||||
# This model `m` will be delegated to Forge memory management system.
|
||||
# `m` will be loaded to GPU everytime when sampling starts.
|
||||
# `m` will be unloaded if necessary.
|
||||
# `m` will influence Forge's judgement about use GPU memory or
|
||||
# capacity and decide whether to use module offload to make user's batch size larger.
|
||||
# Use cast_to_unet_dtype if you want `m` to have same dtype with unet during sampling.
|
||||
|
||||
if cast_to_unet_dtype:
|
||||
m.to(self.model.diffusion_model.dtype)
|
||||
|
||||
patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
self.add_extra_model_patcher_during_sampling(patcher)
|
||||
return patcher
|
||||
|
||||
def add_patched_controlnet(self, cnet):
|
||||
cnet.set_previous_controlnet(self.controlnet_linked_list)
|
||||
self.controlnet_linked_list = cnet
|
||||
return
|
||||
|
||||
def list_controlnets(self):
|
||||
results = []
|
||||
pointer = self.controlnet_linked_list
|
||||
while pointer is not None:
|
||||
results.append(pointer)
|
||||
pointer = pointer.previous_controlnet
|
||||
return results
|
||||
|
||||
def append_model_option(self, k, v, ensure_uniqueness=False):
|
||||
if k not in self.model_options:
|
||||
self.model_options[k] = []
|
||||
|
||||
if ensure_uniqueness and v in self.model_options[k]:
|
||||
return
|
||||
|
||||
self.model_options[k].append(v)
|
||||
return
|
||||
|
||||
def append_transformer_option(self, k, v, ensure_uniqueness=False):
|
||||
if 'transformer_options' not in self.model_options:
|
||||
self.model_options['transformer_options'] = {}
|
||||
|
||||
to = self.model_options['transformer_options']
|
||||
|
||||
if k not in to:
|
||||
to[k] = []
|
||||
|
||||
if ensure_uniqueness and v in to[k]:
|
||||
return
|
||||
|
||||
to[k].append(v)
|
||||
return
|
||||
|
||||
def set_transformer_option(self, k, v):
|
||||
if 'transformer_options' not in self.model_options:
|
||||
self.model_options['transformer_options'] = {}
|
||||
|
||||
self.model_options['transformer_options'][k] = v
|
||||
return
|
||||
|
||||
def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_memory_peak_estimation_modifier(self, modifier):
|
||||
self.model_options['memory_peak_estimation_modifier'] = modifier
|
||||
return
|
||||
|
||||
def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False):
|
||||
"""
|
||||
|
||||
For some reasons, this function only works in A1111's Script.process_batch(self, p, *args, **kwargs)
|
||||
|
||||
For example, below is a worked modification:
|
||||
|
||||
class ExampleScript(scripts.Script):
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
unet = p.sd_model.forge_objects.unet.clone()
|
||||
|
||||
def modifier(x):
|
||||
return x ** 0.5
|
||||
|
||||
unet.add_alphas_cumprod_modifier(modifier)
|
||||
p.sd_model.forge_objects.unet = unet
|
||||
|
||||
return
|
||||
|
||||
This add_alphas_cumprod_modifier is the only patch option that should be used in process_batch()
|
||||
All other patch options should be called in process_before_every_sampling()
|
||||
|
||||
"""
|
||||
|
||||
self.append_model_option('alphas_cumprod_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_block_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_block_inner_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('block_inner_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_controlnet_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_groupnorm_wrapper(self, wrapper):
|
||||
self.set_transformer_option('groupnorm_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_controlnet_model_function_wrapper(self, wrapper):
|
||||
self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_model_replace_all(self, patch, target="attn1"):
|
||||
for block_name in ['input', 'middle', 'output']:
|
||||
for number in range(16):
|
||||
for transformer_index in range(16):
|
||||
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
|
||||
return
|
||||
|
||||
def load_frozen_patcher(self, state_dict, strength):
|
||||
patch_dict = {}
|
||||
for k, w in state_dict.items():
|
||||
model_key, patch_type, weight_index = k.split('::')
|
||||
if model_key not in patch_dict:
|
||||
patch_dict[model_key] = {}
|
||||
if patch_type not in patch_dict[model_key]:
|
||||
patch_dict[model_key][patch_type] = [None] * 16
|
||||
patch_dict[model_key][patch_type][int(weight_index)] = w
|
||||
|
||||
patch_flat = {}
|
||||
for model_key, v in patch_dict.items():
|
||||
for patch_type, weight_list in v.items():
|
||||
patch_flat[model_key] = (patch_type, weight_list)
|
||||
|
||||
self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
return
|
||||
223
backend/patcher/vae.py
Normal file
223
backend/patcher/vae.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import torch
|
||||
import math
|
||||
import itertools
|
||||
|
||||
from tqdm import tqdm
|
||||
from backend import memory_management
|
||||
from backend.patcher.base import ModelPatcher
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
|
||||
dims = len(tile)
|
||||
output = torch.empty([samples.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), samples.shape[2:])), device=output_device)
|
||||
|
||||
for b in range(samples.shape[0]):
|
||||
s = samples[b:b + 1]
|
||||
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
|
||||
|
||||
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
|
||||
s_in = s
|
||||
upscaled = []
|
||||
|
||||
for d in range(dims):
|
||||
pos = max(0, min(s.shape[d + 2] - overlap, it[d]))
|
||||
l = min(tile[d], s.shape[d + 2] - pos)
|
||||
s_in = s_in.narrow(d + 2, pos, l)
|
||||
upscaled.append(round(pos * upscale_amount))
|
||||
ps = function(s_in).to(output_device)
|
||||
mask = torch.ones_like(ps)
|
||||
feather = round(overlap * upscale_amount)
|
||||
for t in range(feather):
|
||||
for d in range(2, dims + 2):
|
||||
m = mask.narrow(d, t, 1)
|
||||
m *= ((1.0 / feather) * (t + 1))
|
||||
m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
|
||||
m *= ((1.0 / feather) * (t + 1))
|
||||
|
||||
o = out
|
||||
o_d = out_div
|
||||
for d in range(dims):
|
||||
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
|
||||
|
||||
o += ps * mask
|
||||
o_d += mask
|
||||
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
|
||||
output[b:b + 1] = out / out_div
|
||||
return output
|
||||
|
||||
|
||||
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
|
||||
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
|
||||
|
||||
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
|
||||
|
||||
|
||||
class ProgressBar:
|
||||
def __init__(self, total, title=None):
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.tqdm = tqdm(total=total, desc=title)
|
||||
|
||||
def update_absolute(self, value, total=None, preview=None):
|
||||
if total is not None:
|
||||
self.total = total
|
||||
if value > self.total:
|
||||
value = self.total
|
||||
inc = value - self.current
|
||||
self.tqdm.update(inc)
|
||||
self.current = value
|
||||
if self.current >= self.total:
|
||||
self.tqdm.close()
|
||||
|
||||
def update(self, value):
|
||||
self.update_absolute(self.current + value)
|
||||
|
||||
|
||||
class VAE:
|
||||
def __init__(self, model=None, device=None, dtype=None, no_init=False):
|
||||
if no_init:
|
||||
return
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * memory_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * memory_management.dtype_size(dtype)
|
||||
self.downscale_ratio = int(2 ** (len(model.config.down_block_types) - 1))
|
||||
self.latent_channels = int(model.config.latent_channels)
|
||||
|
||||
self.first_stage_model = model.eval()
|
||||
|
||||
if device is None:
|
||||
device = memory_management.vae_device()
|
||||
|
||||
self.device = device
|
||||
offload_device = memory_management.vae_offload_device()
|
||||
|
||||
if dtype is None:
|
||||
dtype = memory_management.vae_dtype()
|
||||
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
self.output_device = memory_management.intermediate_device()
|
||||
|
||||
self.patcher = ModelPatcher(
|
||||
self.first_stage_model,
|
||||
load_device=self.device,
|
||||
offload_device=offload_device
|
||||
)
|
||||
|
||||
def clone(self):
|
||||
n = VAE(no_init=True)
|
||||
n.patcher = self.patcher.clone()
|
||||
n.memory_used_encode = self.memory_used_encode
|
||||
n.memory_used_decode = self.memory_used_decode
|
||||
n.downscale_ratio = self.downscale_ratio
|
||||
n.latent_channels = self.latent_channels
|
||||
n.first_stage_model = self.first_stage_model
|
||||
n.device = self.device
|
||||
n.vae_dtype = self.vae_dtype
|
||||
n.output_device = self.output_device
|
||||
return n
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap=16):
|
||||
steps = samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ProgressBar(steps, title='VAE tiled decode')
|
||||
|
||||
decode_fn = lambda a: (self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)) + 1.0).float()
|
||||
output = torch.clamp(((tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar) +
|
||||
tiled_scale(samples, decode_fn, tile_x, tile_y, overlap, upscale_amount=self.downscale_ratio, output_device=self.output_device, pbar=pbar))
|
||||
/ 3.0) / 2.0, min=0.0, max=1.0)
|
||||
return output
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
|
||||
steps = pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += pixel_samples.shape[0] * get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = ProgressBar(steps, title='VAE tiled encode')
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((2. * a - 1.).to(self.vae_dtype).to(self.device)).float()
|
||||
samples = tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples /= 3.0
|
||||
return samples
|
||||
|
||||
def decode_inner(self, samples_in):
|
||||
if memory_management.VAE_ALWAYS_TILED:
|
||||
return self.decode_tiled(samples_in).to(self.output_device)
|
||||
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = memory_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
|
||||
pixel_samples = torch.empty((samples_in.shape[0], 3, round(samples_in.shape[2] * self.downscale_ratio), round(samples_in.shape[3] * self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x + batch_number].to(self.vae_dtype).to(self.device)
|
||||
pixel_samples[x:x + batch_number] = torch.clamp((self.first_stage_model.decode(samples).to(self.output_device).float() + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
except memory_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
|
||||
return pixel_samples
|
||||
|
||||
def decode(self, samples_in):
|
||||
wrapper = self.patcher.model_options.get('model_vae_decode_wrapper', None)
|
||||
if wrapper is None:
|
||||
return self.decode_inner(samples_in)
|
||||
else:
|
||||
return wrapper(self.decode_inner, samples_in)
|
||||
|
||||
def decode_tiled(self, samples, tile_x=64, tile_y=64, overlap=16):
|
||||
memory_management.load_model_gpu(self.patcher)
|
||||
output = self.decode_tiled_(samples, tile_x, tile_y, overlap)
|
||||
return output.movedim(1, -1)
|
||||
|
||||
def encode_inner(self, pixel_samples):
|
||||
if memory_management.VAE_ALWAYS_TILED:
|
||||
return self.encode_tiled(pixel_samples)
|
||||
|
||||
regulation = self.patcher.model_options.get("model_vae_regulation", None)
|
||||
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
try:
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
|
||||
memory_management.load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
free_memory = memory_management.get_free_memory(self.device)
|
||||
batch_number = int(free_memory / memory_used)
|
||||
batch_number = max(1, batch_number)
|
||||
samples = torch.empty((pixel_samples.shape[0], self.latent_channels, round(pixel_samples.shape[2] // self.downscale_ratio), round(pixel_samples.shape[3] // self.downscale_ratio)), device=self.output_device)
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = (2. * pixel_samples[x:x + batch_number] - 1.).to(self.vae_dtype).to(self.device)
|
||||
samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in, regulation).to(self.output_device).float()
|
||||
|
||||
except memory_management.OOM_EXCEPTION as e:
|
||||
print("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
def encode(self, pixel_samples):
|
||||
wrapper = self.patcher.model_options.get('model_vae_encode_wrapper', None)
|
||||
if wrapper is None:
|
||||
return self.encode_inner(pixel_samples)
|
||||
else:
|
||||
return wrapper(self.encode_inner, pixel_samples)
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
|
||||
memory_management.load_model_gpu(self.patcher)
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
return samples
|
||||
@@ -2,7 +2,7 @@ import gradio as gr
|
||||
import ldm_patched.modules.samplers
|
||||
|
||||
from modules import scripts
|
||||
from modules_forge.unet_patcher import copy_and_update_model_options
|
||||
from backend.patcher.base import set_model_options_patch_replace
|
||||
|
||||
|
||||
class PerturbedAttentionGuidanceForForge(scripts.Script):
|
||||
@@ -36,7 +36,7 @@ class PerturbedAttentionGuidanceForForge(scripts.Script):
|
||||
model, cond_denoised, cond, denoised, sigma, x = \
|
||||
args["model"], args["cond_denoised"], args["cond"], args["denoised"], args["sigma"], args["input"]
|
||||
|
||||
new_options = copy_and_update_model_options(args["model_options"], attn_proc, "attn1", "middle", 0)
|
||||
new_options = set_model_options_patch_replace(args["model_options"], attn_proc, "attn1", "middle", 0)
|
||||
|
||||
if scale == 0:
|
||||
return denoised
|
||||
|
||||
@@ -97,9 +97,6 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filen
|
||||
return model, clip
|
||||
|
||||
|
||||
from backend.modules.clip import JointCLIP, JointTokenizer
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, huggingface_components=None, no_init=False):
|
||||
if no_init:
|
||||
|
||||
@@ -4,7 +4,9 @@ import contextlib
|
||||
from ldm_patched.modules import model_management
|
||||
from ldm_patched.modules import model_detection
|
||||
|
||||
from ldm_patched.modules.sd import VAE, CLIP, load_model_weights
|
||||
from ldm_patched.modules.sd import VAE, load_model_weights
|
||||
from backend.patcher.clip import CLIP
|
||||
from backend.patcher.vae import VAE
|
||||
import ldm_patched.modules.model_patcher
|
||||
import ldm_patched.modules.utils
|
||||
import ldm_patched.modules.clip_vision
|
||||
|
||||
@@ -1,215 +1 @@
|
||||
import copy
|
||||
import torch
|
||||
|
||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
||||
from ldm_patched.modules.sample import convert_cond
|
||||
from ldm_patched.modules.samplers import encode_model_conds
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
class UnetPatcher(ModelPatcher):
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
super().__init__(model, *args, **kwargs)
|
||||
self.controlnet_linked_list = None
|
||||
self.extra_preserved_memory_during_sampling = 0
|
||||
self.extra_model_patchers_during_sampling = []
|
||||
self.extra_concat_condition = None
|
||||
|
||||
def clone(self):
|
||||
n = UnetPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device,
|
||||
weight_inplace_update=self.weight_inplace_update)
|
||||
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.controlnet_linked_list = self.controlnet_linked_list
|
||||
n.extra_preserved_memory_during_sampling = self.extra_preserved_memory_during_sampling
|
||||
n.extra_model_patchers_during_sampling = self.extra_model_patchers_during_sampling.copy()
|
||||
n.extra_concat_condition = self.extra_concat_condition
|
||||
return n
|
||||
|
||||
def add_extra_preserved_memory_during_sampling(self, memory_in_bytes: int):
|
||||
# Use this to ask Forge to preserve a certain amount of memory during sampling.
|
||||
# If GPU VRAM is 8 GB, and memory_in_bytes is 2GB, i.e., memory_in_bytes = 2 * 1024 * 1024 * 1024
|
||||
# Then the sampling will always use less than 6GB memory by dynamically offload modules to CPU RAM.
|
||||
# You can estimate this using model_management.module_size(any_pytorch_model) to get size of any pytorch models.
|
||||
self.extra_preserved_memory_during_sampling += memory_in_bytes
|
||||
return
|
||||
|
||||
def add_extra_model_patcher_during_sampling(self, model_patcher: ModelPatcher):
|
||||
# Use this to ask Forge to move extra model patchers to GPU during sampling.
|
||||
# This method will manage GPU memory perfectly.
|
||||
self.extra_model_patchers_during_sampling.append(model_patcher)
|
||||
return
|
||||
|
||||
def add_extra_torch_module_during_sampling(self, m: torch.nn.Module, cast_to_unet_dtype: bool = True):
|
||||
# Use this method to bind an extra torch.nn.Module to this UNet during sampling.
|
||||
# This model `m` will be delegated to Forge memory management system.
|
||||
# `m` will be loaded to GPU everytime when sampling starts.
|
||||
# `m` will be unloaded if necessary.
|
||||
# `m` will influence Forge's judgement about use GPU memory or
|
||||
# capacity and decide whether to use module offload to make user's batch size larger.
|
||||
# Use cast_to_unet_dtype if you want `m` to have same dtype with unet during sampling.
|
||||
|
||||
if cast_to_unet_dtype:
|
||||
m.to(self.model.diffusion_model.dtype)
|
||||
|
||||
patcher = ModelPatcher(model=m, load_device=self.load_device, offload_device=self.offload_device)
|
||||
|
||||
self.add_extra_model_patcher_during_sampling(patcher)
|
||||
return patcher
|
||||
|
||||
def add_patched_controlnet(self, cnet):
|
||||
cnet.set_previous_controlnet(self.controlnet_linked_list)
|
||||
self.controlnet_linked_list = cnet
|
||||
return
|
||||
|
||||
def list_controlnets(self):
|
||||
results = []
|
||||
pointer = self.controlnet_linked_list
|
||||
while pointer is not None:
|
||||
results.append(pointer)
|
||||
pointer = pointer.previous_controlnet
|
||||
return results
|
||||
|
||||
def append_model_option(self, k, v, ensure_uniqueness=False):
|
||||
if k not in self.model_options:
|
||||
self.model_options[k] = []
|
||||
|
||||
if ensure_uniqueness and v in self.model_options[k]:
|
||||
return
|
||||
|
||||
self.model_options[k].append(v)
|
||||
return
|
||||
|
||||
def append_transformer_option(self, k, v, ensure_uniqueness=False):
|
||||
if 'transformer_options' not in self.model_options:
|
||||
self.model_options['transformer_options'] = {}
|
||||
|
||||
to = self.model_options['transformer_options']
|
||||
|
||||
if k not in to:
|
||||
to[k] = []
|
||||
|
||||
if ensure_uniqueness and v in to[k]:
|
||||
return
|
||||
|
||||
to[k].append(v)
|
||||
return
|
||||
|
||||
def set_transformer_option(self, k, v):
|
||||
if 'transformer_options' not in self.model_options:
|
||||
self.model_options['transformer_options'] = {}
|
||||
|
||||
self.model_options['transformer_options'][k] = v
|
||||
return
|
||||
|
||||
def add_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_sampler_pre_cfg_function(self, modifier, ensure_uniqueness=False):
|
||||
self.append_model_option('sampler_pre_cfg_function', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_memory_peak_estimation_modifier(self, modifier):
|
||||
self.model_options['memory_peak_estimation_modifier'] = modifier
|
||||
return
|
||||
|
||||
def add_alphas_cumprod_modifier(self, modifier, ensure_uniqueness=False):
|
||||
"""
|
||||
|
||||
For some reasons, this function only works in A1111's Script.process_batch(self, p, *args, **kwargs)
|
||||
|
||||
For example, below is a worked modification:
|
||||
|
||||
class ExampleScript(scripts.Script):
|
||||
|
||||
def process_batch(self, p, *args, **kwargs):
|
||||
unet = p.sd_model.forge_objects.unet.clone()
|
||||
|
||||
def modifier(x):
|
||||
return x ** 0.5
|
||||
|
||||
unet.add_alphas_cumprod_modifier(modifier)
|
||||
p.sd_model.forge_objects.unet = unet
|
||||
|
||||
return
|
||||
|
||||
This add_alphas_cumprod_modifier is the only patch option that should be used in process_batch()
|
||||
All other patch options should be called in process_before_every_sampling()
|
||||
|
||||
"""
|
||||
|
||||
self.append_model_option('alphas_cumprod_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_block_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('block_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_block_inner_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('block_inner_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def add_controlnet_conditioning_modifier(self, modifier, ensure_uniqueness=False):
|
||||
self.append_transformer_option('controlnet_conditioning_modifiers', modifier, ensure_uniqueness)
|
||||
return
|
||||
|
||||
def set_groupnorm_wrapper(self, wrapper):
|
||||
self.set_transformer_option('groupnorm_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_controlnet_model_function_wrapper(self, wrapper):
|
||||
self.set_transformer_option('controlnet_model_function_wrapper', wrapper)
|
||||
return
|
||||
|
||||
def set_model_replace_all(self, patch, target="attn1"):
|
||||
for block_name in ['input', 'middle', 'output']:
|
||||
for number in range(16):
|
||||
for transformer_index in range(16):
|
||||
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
|
||||
return
|
||||
|
||||
def encode_conds_after_clip(self, conds, noise, prompt_type="positive"):
|
||||
return encode_model_conds(
|
||||
model_function=self.model.extra_conds,
|
||||
conds=convert_cond(conds),
|
||||
noise=noise,
|
||||
device=noise.device,
|
||||
prompt_type=prompt_type
|
||||
)
|
||||
|
||||
def load_frozen_patcher(self, state_dict, strength):
|
||||
patch_dict = {}
|
||||
for k, w in state_dict.items():
|
||||
model_key, patch_type, weight_index = k.split('::')
|
||||
if model_key not in patch_dict:
|
||||
patch_dict[model_key] = {}
|
||||
if patch_type not in patch_dict[model_key]:
|
||||
patch_dict[model_key][patch_type] = [None] * 16
|
||||
patch_dict[model_key][patch_type][int(weight_index)] = w
|
||||
|
||||
patch_flat = {}
|
||||
for model_key, v in patch_dict.items():
|
||||
for patch_type, weight_list in v.items():
|
||||
patch_flat[model_key] = (patch_type, weight_list)
|
||||
|
||||
self.add_patches(patches=patch_flat, strength_patch=float(strength), strength_model=1.0)
|
||||
return
|
||||
|
||||
|
||||
def copy_and_update_model_options(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
model_options = model_options.copy()
|
||||
transformer_options = model_options.get("transformer_options", {}).copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {}).copy()
|
||||
name_patches = patches_replace.get(name, {}).copy()
|
||||
block = (block_name, number, transformer_index) if transformer_index is not None else (block_name, number)
|
||||
name_patches[block] = patch
|
||||
patches_replace[name] = name_patches
|
||||
transformer_options["patches_replace"] = patches_replace
|
||||
model_options["transformer_options"] = transformer_options
|
||||
return model_options
|
||||
from backend.patcher.unet import *
|
||||
|
||||
Reference in New Issue
Block a user