Commit: backend
Repository: mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
@@ -825,6 +825,7 @@ class UNetModel(nn.Module):
         transformer_options["original_shape"] = list(x.shape)
         transformer_options["transformer_index"] = 0
         transformer_patches = transformer_options.get("patches", {})
+        block_modifiers = transformer_options.get("block_modifiers", [])

         num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
         image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
@@ -844,8 +845,16 @@ class UNetModel(nn.Module):
         h = x
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'before', transformer_options)
+
             h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
             h = apply_control(h, control, 'input')
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'after', transformer_options)
+
             if "input_block_patch" in transformer_patches:
                 patch = transformer_patches["input_block_patch"]
                 for p in patch:
@@ -858,9 +867,15 @@ class UNetModel(nn.Module):
                     h = p(h, transformer_options)

         transformer_options["block"] = ("middle", 0)
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'before', transformer_options)
+
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = apply_control(h, control, 'middle')
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'after', transformer_options)

         for id, module in enumerate(self.output_blocks):
             transformer_options["block"] = ("output", id)
@@ -878,9 +893,26 @@ class UNetModel(nn.Module):
                 output_shape = hs[-1].shape
             else:
                 output_shape = None
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'before', transformer_options)
+
             h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-        h = h.type(x.dtype)
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'after', transformer_options)
+
+        transformer_options["block"] = ("last", 0)
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'before', transformer_options)
+
         if self.predict_codebook_ids:
-            return self.id_predictor(h)
+            h = self.id_predictor(h)
         else:
-            return self.out(h)
+            h = self.out(h)
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'after', transformer_options)
+
+        return h.type(x.dtype)
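Note: the hunks above thread an optional list of `block_modifiers` callbacks around every UNet block, identified via `transformer_options["block"]`. A minimal sketch of what such a callback receives and how it could be registered; the `log_shape` name and the standalone driver loop are illustrative, not forge API:

```python
import torch

def log_shape(h, flag, transformer_options):
    # flag is 'before' or 'after'; "block" identifies the current UNet stage.
    block_type, block_id = transformer_options["block"]
    print(f"{flag} {block_type}[{block_id}]: {tuple(h.shape)}")
    return h  # a modifier must return the (possibly edited) hidden states

transformer_options = {"block_modifiers": [log_shape], "block": ("input", 0)}

# The patched forward calls every modifier around every block, roughly like this:
h = torch.randn(1, 4, 64, 64)
for block_modifier in transformer_options["block_modifiers"]:
    h = block_modifier(h, 'before', transformer_options)
```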
@@ -225,19 +225,13 @@ class CheckpointFunction(torch.autograd.Function):


 def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
-    """
-    Create sinusoidal timestep embeddings.
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
+    # Consistent with Kohya to reduce differences between model training and inference.
+
     if not repeat_only:
         half = dim // 2
         freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
-        )
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
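Note: the only functional change here is where `torch.arange` runs. The frequency table is now built on the default device in float32 and moved to the target device afterwards, matching Kohya's training-time computation (the same form appears verbatim in the deleted `patched_timestep_embedding` near the end of this diff). A small self-contained check of the new formulation:

```python
import math
import torch

def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = timestep_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=320)
print(emb.shape)  # torch.Size([3, 320])
```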
@@ -11,6 +11,9 @@ import ldm_patched.controlnet.cldm
 import ldm_patched.t2ia.adapter


+compute_controlnet_weighting = None
+
+
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -114,6 +117,10 @@ class ControlBase:
                     x = x.to(output_dtype)

                 out[key].append(x)
+
+        if compute_controlnet_weighting is not None:
+            out = compute_controlnet_weighting(out, self)
+
         if control_prev is not None:
             for x in ['input', 'middle', 'output']:
                 o = out[x]
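Note: `compute_controlnet_weighting` starts life as a module-level `None` and is only called once something assigns it; the `patch_all_basics` hunk further down does exactly that. A minimal sketch of this late-binding hook idiom, with hypothetical names of my own choosing:

```python
# A core module ships a disabled hook as a module-level name...
compute_weighting = None

def merge(out):
    # ...and only calls it once an extension has filled it in.
    if compute_weighting is not None:
        out = compute_weighting(out)
    return out

# An extension enables the hook later, without the core importing the extension:
def scale_everything(out):
    return {k: [x * 0.5 for x in v] for k, v in out.items()}

compute_weighting = scale_everything
print(merge({"input": [1.0, 2.0]}))  # {'input': [0.5, 1.0]}
```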
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import psutil
|
||||
from enum import Enum
|
||||
from ldm_patched.modules.args_parser import args
|
||||
@@ -42,8 +43,6 @@ if args.directml is not None:
     else:
         directml_device = torch_directml.device(device_index)
     print("Using directml with device:", torch_directml.device_name(device_index))
-    # torch_directml.disable_tiled_resources(True)
-    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.

 try:
     import intel_extension_for_pytorch as ipex
@@ -128,6 +127,9 @@ try:
 except:
     OOM_EXCEPTION = Exception

+if directml_enabled:
+    OOM_EXCEPTION = Exception
+
 XFORMERS_VERSION = ""
 XFORMERS_ENABLED_VAE = True
 if args.disable_xformers:
@@ -376,6 +378,8 @@ def free_memory(memory_required, device, keep_loaded=[]):
 def load_models_gpu(models, memory_required=0):
     global vram_state

+    execution_start_time = time.perf_counter()
+
     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)

@@ -390,7 +394,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
+                print(f"To load target model {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)

     if len(models_to_load) == 0:
@@ -398,9 +402,14 @@ def load_models_gpu(models, memory_required=0):
         for d in devs:
             if d != torch.device("cpu"):
                 free_memory(extra_mem, d, models_already_loaded)
+
+        moving_time = time.perf_counter() - execution_start_time
+        if moving_time > 0.1:
+            print(f'Moving model(s) skipped. Freeing memory has taken {moving_time:.2f} seconds')
+
         return

-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")

     total_memory_required = {}
     for loaded_model in models_to_load:
@@ -433,6 +442,11 @@ def load_models_gpu(models, memory_required=0):

         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
+
+    moving_time = time.perf_counter() - execution_start_time
+    if moving_time > 0.1:
+        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
+
     return

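Note: both paths use the same `time.perf_counter()` bracket around the slow work and print only when more than 0.1 s elapsed; this inlines what `patched_load_models_gpu` (removed from `patch_basic.py` below) used to bolt on from outside. The pattern in isolation, with `do_work` as a stand-in for freeing memory or moving models:

```python
import time

def do_work():
    time.sleep(0.2)  # stand-in for the actual model moves

execution_start_time = time.perf_counter()
do_work()
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
    print(f'Moving model(s) has taken {moving_time:.2f} seconds')
```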
@@ -1,5 +1,24 @@
 import torch
 import ldm_patched.modules.model_management
+import contextlib
+
+
+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+
+        yield
+
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+        return


 def cast_bias_weight(s, input):
     bias = None
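Note: `use_patched_ops` temporarily swaps the five core `torch.nn` layer classes for the ones on `operations` and restores the originals even if construction fails. One way it could be used, assuming the function above is in scope; the `LoggingLinear` stand-in plays the role that `ldm_patched.modules.ops.manual_cast` plays in the diff:

```python
import types
import torch.nn as nn

class LoggingLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        print("constructed a patched Linear")

# Stand-in operations namespace with the five expected attributes.
operations = types.SimpleNamespace(Linear=LoggingLinear, Conv2d=nn.Conv2d, Conv3d=nn.Conv3d,
                                   GroupNorm=nn.GroupNorm, LayerNorm=nn.LayerNorm)

original_linear = nn.Linear
with use_patched_ops(operations):
    layer = nn.Linear(8, 8)  # actually constructs a LoggingLinear

assert nn.Linear is original_linear   # originals restored on exit
assert isinstance(layer, LoggingLinear)
```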
@@ -126,6 +126,29 @@ def cond_cat(c_list):

     return out

+def compute_cond_mark(cond_or_uncond, sigmas):
+    cond_or_uncond_size = int(sigmas.shape[0])
+
+    cond_mark = []
+    for cx in cond_or_uncond:
+        cond_mark += [cx] * cond_or_uncond_size
+
+    cond_mark = torch.Tensor(cond_mark).to(sigmas)
+    return cond_mark
+
+def compute_cond_indices(cond_or_uncond, sigmas):
+    cl = int(sigmas.shape[0])
+
+    cond_indices = []
+    uncond_indices = []
+    for i, cx in enumerate(cond_or_uncond):
+        if cx == 0:
+            cond_indices += list(range(i * cl, (i + 1) * cl))
+        else:
+            uncond_indices += list(range(i * cl, (i + 1) * cl))
+
+    return cond_indices, uncond_indices
+
 def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
     out_cond = torch.zeros_like(x_in)
     out_count = torch.ones_like(x_in) * 1e-37
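Note: concretely, for a batch whose chunks are `[cond, uncond]` with two sigma entries per chunk, the helpers above mark each flattened batch row. The values below follow directly from the definitions:

```python
import torch

cond_or_uncond = [0, 1]              # 0 = cond chunk, 1 = uncond chunk
sigmas = torch.tensor([14.6, 14.6])  # rows per chunk is sigmas.shape[0] == 2

print(compute_cond_mark(cond_or_uncond, sigmas))     # tensor([0., 0., 1., 1.])
print(compute_cond_indices(cond_or_uncond, sigmas))  # ([0, 1], [2, 3])
```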
@@ -193,9 +216,6 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
         c = cond_cat(c)
         timestep_ = torch.cat([timestep] * batch_chunks)

-        if control is not None:
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
         transformer_options = {}
         if 'transformer_options' in model_options:
             transformer_options = model_options['transformer_options'].copy()
@@ -214,8 +234,18 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
         transformer_options["cond_or_uncond"] = cond_or_uncond[:]
         transformer_options["sigmas"] = timestep

+        transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
+        transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
+
         c['transformer_options'] = transformer_options

+        if control is not None:
+            p = control
+            while p is not None:
+                p.transformer_options = transformer_options
+                p = p.previous_controlnet
+            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+
         if 'model_function_wrapper' in model_options:
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
         else:
@@ -8,6 +8,7 @@ import zipfile
 from . import model_management
 import ldm_patched.modules.clip_model
 import json
+from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils

 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -74,11 +75,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")

-        with open(textmodel_json_config) as f:
-            config = json.load(f)
+        config = CLIPTextConfig.from_json_file(textmodel_json_config)
+        self.num_layers = config.num_hidden_layers

-        self.transformer = model_class(config, dtype, device, ldm_patched.modules.ops.manual_cast)
-        self.num_layers = self.transformer.num_layers
+        with ldm_patched.modules.ops.use_patched_ops(ldm_patched.modules.ops.manual_cast):
+            with modeling_utils.no_init_weights():
+                self.transformer = CLIPTextModel(config)
+
+        if dtype is not None:
+            self.transformer.to(dtype)
+
+        self.transformer.text_model.embeddings.to(torch.float32)

         self.max_length = max_length
         if freeze:
@@ -169,16 +176,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                     if tokens[x, y] == max_token:
                         break

-        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
+                                   output_hidden_states=self.layer == "hidden")
         self.transformer.set_input_embeddings(backup_embeds)

         if self.layer == "last":
-            z = outputs[0]
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
         else:
-            z = outputs[1]
+            z = outputs.hidden_states[self.layer_idx]
+            if self.layer_norm_hidden_state:
+                z = self.transformer.text_model.final_layer_norm(z)

-        if outputs[2] is not None:
-            pooled_output = outputs[2].float()
+        if hasattr(outputs, "pooler_output"):
+            pooled_output = outputs.pooler_output.float()
         else:
             pooled_output = None
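Note: the rewritten forward goes through Hugging Face's `CLIPTextModel`, so layer selection uses the named fields of its output object instead of tuple indices. A sketch of the field mapping that runs without downloading weights, using a deliberately tiny illustrative config:

```python
import torch
from transformers import CLIPTextModel, CLIPTextConfig

config = CLIPTextConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2,
                        num_attention_heads=2, vocab_size=49408, max_position_embeddings=77)
model = CLIPTextModel(config)  # randomly initialized, just to show the output shapes

tokens = torch.randint(0, 49408, (1, 77))
outputs = model(input_ids=tokens, output_hidden_states=True)

print(outputs.last_hidden_state.shape)  # layer == "last"
print(outputs.pooler_output.shape)      # layer == "pooled"
print(len(outputs.hidden_states))       # embeddings + one entry per layer; layer == "hidden" indexes this
```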
@@ -33,19 +33,13 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]

 folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})

-output_directory = os.path.join(os.getcwd(), "output")
-temp_directory = os.path.join(os.getcwd(), "temp")
-input_directory = os.path.join(os.getcwd(), "input")
-user_directory = os.path.join(os.getcwd(), "user")
+output_directory = os.path.join(base_path, "output")
+temp_directory = os.path.join(base_path, "temp")
+input_directory = os.path.join(base_path, "input")
+user_directory = os.path.join(base_path, "user")

 filename_list_cache = {}

-if not os.path.exists(input_directory):
-    try:
-        pass # os.makedirs(input_directory)
-    except:
-        print("Failed to create input directory")
-
 def set_output_directory(output_dir):
     global output_directory
     output_directory = output_dir
@@ -66,14 +66,14 @@ def apply_controlnet_advanced(
     return m


-def compute_controlnet_weighting(
-        control,
-        positive_advanced_weighting,
-        negative_advanced_weighting,
-        advanced_frame_weighting,
-        advanced_sigma_weighting,
-        transformer_options
-):
+def compute_controlnet_weighting(control, cnet):
+    positive_advanced_weighting = cnet.positive_advanced_weighting
+    negative_advanced_weighting = cnet.negative_advanced_weighting
+    advanced_frame_weighting = cnet.advanced_frame_weighting
+    advanced_sigma_weighting = cnet.advanced_sigma_weighting
+    transformer_options = cnet.transformer_options
+
     if positive_advanced_weighting is None and negative_advanced_weighting is None \
             and advanced_frame_weighting is None and advanced_sigma_weighting is None:
         return control
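Note: after this change the weighting options travel on the ControlNet object itself rather than through the call signature, which is what lets the generic `ControlBase.control_merge` hook call it with just `(out, self)`. A sketch of the new calling convention using the function above and a dummy `cnet` (the `SimpleNamespace` stand-in is mine):

```python
import types

cnet = types.SimpleNamespace(
    positive_advanced_weighting=None,
    negative_advanced_weighting=None,
    advanced_frame_weighting=None,
    advanced_sigma_weighting=None,
    transformer_options={},
)

control = {'input': [], 'middle': [], 'output': []}
# With every option left at None the function is a no-op and hands the dict back:
assert compute_controlnet_weighting(control, cnet) is control
```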
@@ -24,22 +24,9 @@ def initialize_forge():
     torch.zeros((1, 1)).to(device, torch.float32)
     model_management.soft_empty_cache()

-    import modules_forge.patch_clip
-    modules_forge.patch_clip.patch_all_clip()
-
-    import modules_forge.patch_precision
-    modules_forge.patch_precision.patch_all_precision()
-
     import modules_forge.patch_basic
     modules_forge.patch_basic.patch_all_basics()

-    import modules_forge.unet_patcher
-    modules_forge.unet_patcher.patch_all()
-
-    if model_management.directml_enabled:
-        model_management.lowvram_available = True
-        model_management.OOM_EXCEPTION = Exception
-
     from modules_forge import supported_preprocessor
     from modules_forge import supported_controlnet
@@ -2,23 +2,7 @@ import time
 import torch
 import contextlib
 from ldm_patched.modules import model_management
-
-
-@contextlib.contextmanager
-def use_patched_ops(operations):
-    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
-    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
-
-    try:
-        for op_name in op_names:
-            setattr(torch.nn, op_name, getattr(operations, op_name))
-
-        yield
-
-    finally:
-        for op_name in op_names:
-            setattr(torch.nn, op_name, backups[op_name])
-        return
+from ldm_patched.modules.ops import use_patched_ops


 @contextlib.contextmanager
@@ -1,201 +1,6 @@
 import torch
-import os
-import time
 import safetensors
-import ldm_patched.modules.samplers
-
-from ldm_patched.modules.controlnet import ControlBase
-from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
-from ldm_patched.modules import model_management
-from modules_forge.controlnet import compute_controlnet_weighting
-from modules_forge.forge_util import compute_cond_mark, compute_cond_indices
-
-
-def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
-    out = {'input': [], 'middle': [], 'output': []}
-
-    if control_input is not None:
-        for i in range(len(control_input)):
-            key = 'input'
-            x = control_input[i]
-            if x is not None:
-                x *= self.strength
-                if x.dtype != output_dtype:
-                    x = x.to(output_dtype)
-            out[key].insert(0, x)
-
-    if control_output is not None:
-        for i in range(len(control_output)):
-            if i == (len(control_output) - 1):
-                key = 'middle'
-                index = 0
-            else:
-                key = 'output'
-                index = i
-            x = control_output[i]
-            if x is not None:
-                if self.global_average_pooling:
-                    x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
-
-                x *= self.strength
-                if x.dtype != output_dtype:
-                    x = x.to(output_dtype)
-
-            out[key].append(x)
-
-    out = compute_controlnet_weighting(
-        out,
-        positive_advanced_weighting=self.positive_advanced_weighting,
-        negative_advanced_weighting=self.negative_advanced_weighting,
-        advanced_frame_weighting=self.advanced_frame_weighting,
-        advanced_sigma_weighting=self.advanced_sigma_weighting,
-        transformer_options=self.transformer_options
-    )
-
-    if control_prev is not None:
-        for x in ['input', 'middle', 'output']:
-            o = out[x]
-            for i in range(len(control_prev[x])):
-                prev_val = control_prev[x][i]
-                if i >= len(o):
-                    o.append(prev_val)
-                elif prev_val is not None:
-                    if o[i] is None:
-                        o[i] = prev_val
-                    else:
-                        if o[i].shape[0] < prev_val.shape[0]:
-                            o[i] = prev_val + o[i]
-                        else:
-                            o[i] += prev_val
-    return out
-
-
-def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-    out_cond = torch.zeros_like(x_in)
-    out_count = torch.ones_like(x_in) * 1e-37
-
-    out_uncond = torch.zeros_like(x_in)
-    out_uncond_count = torch.ones_like(x_in) * 1e-37
-
-    COND = 0
-    UNCOND = 1
-
-    to_run = []
-    for x in cond:
-        p = get_area_and_mult(x, x_in, timestep)
-        if p is None:
-            continue
-
-        to_run += [(p, COND)]
-    if uncond is not None:
-        for x in uncond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
-
-            to_run += [(p, UNCOND)]
-
-    while len(to_run) > 0:
-        first = to_run[0]
-        first_shape = first[0][0].shape
-        to_batch_temp = []
-        for x in range(len(to_run)):
-            if can_concat_cond(to_run[x][0], first[0]):
-                to_batch_temp += [x]
-
-        to_batch_temp.reverse()
-        to_batch = to_batch_temp[:1]
-
-        free_memory = model_management.get_free_memory(x_in.device)
-        for i in range(1, len(to_batch_temp) + 1):
-            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
-            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) < free_memory:
-                to_batch = batch_amount
-                break
-
-        input_x = []
-        mult = []
-        c = []
-        cond_or_uncond = []
-        area = []
-        control = None
-        patches = None
-        for x in to_batch:
-            o = to_run.pop(x)
-            p = o[0]
-            input_x.append(p.input_x)
-            mult.append(p.mult)
-            c.append(p.conditioning)
-            area.append(p.area)
-            cond_or_uncond.append(o[1])
-            control = p.control
-            patches = p.patches
-
-        batch_chunks = len(cond_or_uncond)
-        input_x = torch.cat(input_x)
-        c = cond_cat(c)
-        timestep_ = torch.cat([timestep] * batch_chunks)
-
-        transformer_options = {}
-        if 'transformer_options' in model_options:
-            transformer_options = model_options['transformer_options'].copy()
-
-        if patches is not None:
-            if "patches" in transformer_options:
-                cur_patches = transformer_options["patches"].copy()
-                for p in patches:
-                    if p in cur_patches:
-                        cur_patches[p] = cur_patches[p] + patches[p]
-                    else:
-                        cur_patches[p] = patches[p]
-            else:
-                transformer_options["patches"] = patches
-
-        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-        transformer_options["sigmas"] = timestep
-
-        transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
-        transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
-
-        c['transformer_options'] = transformer_options
-
-        if control is not None:
-            p = control
-            while p is not None:
-                p.transformer_options = transformer_options
-                p = p.previous_controlnet
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
-        if 'model_function_wrapper' in model_options:
-            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
-        else:
-            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
-        del input_x
-
-        for o in range(batch_chunks):
-            if cond_or_uncond[o] == COND:
-                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            else:
-                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-        del mult
-
-    out_cond /= out_count
-    del out_count
-    out_uncond /= out_uncond_count
-    del out_uncond_count
-    return out_cond, out_uncond
-
-
-def patched_load_models_gpu(*args, **kwargs):
-    execution_start_time = time.perf_counter()
-    y = model_management.load_models_gpu_origin(*args, **kwargs)
-    moving_time = time.perf_counter() - execution_start_time
-    if moving_time > 0.1:
-        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
-    return y


 def build_loaded(module, loader_name):
@@ -233,14 +38,10 @@ def build_loaded(module, loader_name):


 def patch_all_basics():
-    if not hasattr(model_management, 'load_models_gpu_origin'):
-        model_management.load_models_gpu_origin = model_management.load_models_gpu
-
-    model_management.load_models_gpu = patched_load_models_gpu
-
-    ControlBase.control_merge = patched_control_merge
-    ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch
+    import ldm_patched.modules.controlnet
+    import modules_forge.controlnet
+
+    ldm_patched.modules.controlnet.compute_controlnet_weighting = modules_forge.controlnet.compute_controlnet_weighting
     build_loaded(safetensors.torch, 'load_file')
     build_loaded(torch, 'load')
     return
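Note: `build_loaded`'s body sits outside this diff, but its two call sites show the shape of the idea: wrap a module's loader function once and keep a reference to the original. A generic sketch under that assumption; the `_origin` suffix matches the `load_models_gpu_origin` convention seen above, while the error handling is illustrative, not forge's actual wording:

```python
def build_loaded(module, loader_name):
    original_name = loader_name + '_origin'

    # Idempotent: keep exactly one reference to the untouched loader.
    if not hasattr(module, original_name):
        setattr(module, original_name, getattr(module, loader_name))

    original_loader = getattr(module, original_name)

    def loader(*args, **kwargs):
        try:
            return original_loader(*args, **kwargs)
        except Exception as e:
            # A real implementation could report the corrupted file here.
            raise ValueError(f'Failed to load with {loader_name}: {e}') from e

    setattr(module, loader_name, loader)
    return
```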
@@ -1,112 +0,0 @@
-# Consistent with Kohya/A1111 to reduce differences between model training and inference.
-
-import os
-import torch
-import ldm_patched.controlnet.cldm
-import ldm_patched.k_diffusion.sampling
-import ldm_patched.ldm.modules.attention
-import ldm_patched.ldm.modules.diffusionmodules.model
-import ldm_patched.ldm.modules.diffusionmodules.openaimodel
-import ldm_patched.ldm.modules.diffusionmodules.openaimodel
-import ldm_patched.modules.args_parser
-import ldm_patched.modules.model_base
-import ldm_patched.modules.model_management
-import ldm_patched.modules.model_patcher
-import ldm_patched.modules.samplers
-import ldm_patched.modules.sd
-import ldm_patched.modules.sd1_clip
-import ldm_patched.modules.clip_vision
-import ldm_patched.modules.ops as ops
-
-from modules_forge.ops import use_patched_ops
-from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
-
-
-def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None,
-                                textmodel_json_config=None, dtype=None, special_tokens=None,
-                                layer_norm_hidden_state=True, **kwargs):
-    torch.nn.Module.__init__(self)
-    assert layer in self.LAYERS
-
-    if special_tokens is None:
-        special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
-
-    if textmodel_json_config is None:
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)),
-                                             "sd1_clip_config.json")
-
-    config = CLIPTextConfig.from_json_file(textmodel_json_config)
-    self.num_layers = config.num_hidden_layers
-
-    with use_patched_ops(ops.manual_cast):
-        with modeling_utils.no_init_weights():
-            self.transformer = CLIPTextModel(config)
-
-    if dtype is not None:
-        self.transformer.to(dtype)
-
-    self.transformer.text_model.embeddings.to(torch.float32)
-
-    if freeze:
-        self.freeze()
-
-    self.max_length = max_length
-    self.layer = layer
-    self.layer_idx = None
-    self.special_tokens = special_tokens
-    self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
-    self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-    self.enable_attention_masks = False
-
-    self.layer_norm_hidden_state = layer_norm_hidden_state
-    if layer == "hidden":
-        assert layer_idx is not None
-        assert abs(layer_idx) < self.num_layers
-        self.clip_layer(layer_idx)
-    self.layer_default = (self.layer, self.layer_idx)
-
-
-def patched_SDClipModel_forward(self, tokens):
-    backup_embeds = self.transformer.get_input_embeddings()
-    device = backup_embeds.weight.device
-    tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-    tokens = torch.LongTensor(tokens).to(device)
-
-    attention_mask = None
-    if self.enable_attention_masks:
-        attention_mask = torch.zeros_like(tokens)
-        max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-        for x in range(attention_mask.shape[0]):
-            for y in range(attention_mask.shape[1]):
-                attention_mask[x, y] = 1
-                if tokens[x, y] == max_token:
-                    break
-
-    outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
-                               output_hidden_states=self.layer == "hidden")
-    self.transformer.set_input_embeddings(backup_embeds)
-
-    if self.layer == "last":
-        z = outputs.last_hidden_state
-    elif self.layer == "pooled":
-        z = outputs.pooler_output[:, None, :]
-    else:
-        z = outputs.hidden_states[self.layer_idx]
-        if self.layer_norm_hidden_state:
-            z = self.transformer.text_model.final_layer_norm(z)
-
-    if hasattr(outputs, "pooler_output"):
-        pooled_output = outputs.pooler_output.float()
-    else:
-        pooled_output = None
-
-    if self.text_projection is not None and pooled_output is not None:
-        pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-
-    return z.float(), pooled_output
-
-
-def patch_all_clip():
-    ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
-    ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
-    return
@@ -1,60 +0,0 @@
-# Consistent with Kohya to reduce differences between model training and inference.
-
-import torch
-import math
-import einops
-import numpy as np
-
-import ldm_patched.ldm.modules.diffusionmodules.openaimodel
-import ldm_patched.modules.model_sampling
-import ldm_patched.modules.sd1_clip
-
-from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule
-
-
-def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
-    # Consistent with Kohya to reduce differences between model training and inference.
-
-    if not repeat_only:
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
-        args = timesteps[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    else:
-        embedding = einops.repeat(timesteps, 'b -> b d', d=dim)
-    return embedding
-
-
-def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
-                              linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-    # Consistent with Kohya to reduce differences between model training and inference.
-
-    if given_betas is not None:
-        betas = given_betas
-    else:
-        betas = make_beta_schedule(
-            beta_schedule,
-            timesteps,
-            linear_start=linear_start,
-            linear_end=linear_end,
-            cosine_s=cosine_s)
-
-    alphas = 1. - betas
-    alphas_cumprod = np.cumprod(alphas, axis=0)
-    timesteps, = betas.shape
-    self.num_timesteps = int(timesteps)
-    self.linear_start = linear_start
-    self.linear_end = linear_end
-    sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
-    self.set_sigmas(sigmas)
-    return
-
-
-def patch_all_precision():
-    ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding
-    ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule
-    return
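Note: the deleted `patched_register_schedule` (whose schedule logic now lives inside `ldm_patched`, like the timestep embedding earlier in this diff) builds the discrete noise schedule in float64 NumPy first and only then converts to a float32 sigma tensor via sigma = sqrt((1 - alpha_cumprod) / alpha_cumprod). A small numeric check of that formula on an illustrative 10-step linear schedule:

```python
import numpy as np
import torch

betas = np.linspace(1e-4, 2e-2, 10)  # illustrative short linear schedule
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)

sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
print(sigmas.shape, float(sigmas[0]), float(sigmas[-1]))  # sigmas increase with timestep
```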
@@ -1,7 +1,5 @@
 import copy
-import torch

-from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed, apply_control
 from ldm_patched.modules.model_patcher import ModelPatcher


@@ -76,111 +74,3 @@ class UnetPatcher(ModelPatcher):
         for transformer_index in range(16):
             self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
         return
-
-
-def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
-    if transformer_options is None:
-        transformer_options = {}
-
-    transformer_options["original_shape"] = list(x.shape)
-    transformer_options["transformer_index"] = 0
-    transformer_patches = transformer_options.get("patches", {})
-    block_modifiers = transformer_options.get("block_modifiers", [])
-
-    num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
-    image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
-    time_context = kwargs.get("time_context", None)
-
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape[0] == x.shape[0]
-        emb = emb + self.label_emb(y)
-
-    h = x
-    for id, module in enumerate(self.input_blocks):
-        transformer_options["block"] = ("input", id)
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'before', transformer_options)
-
-        h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context,
-                                   num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-        h = apply_control(h, control, 'input')
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'after', transformer_options)
-
-        if "input_block_patch" in transformer_patches:
-            patch = transformer_patches["input_block_patch"]
-            for p in patch:
-                h = p(h, transformer_options)
-
-        hs.append(h)
-        if "input_block_patch_after_skip" in transformer_patches:
-            patch = transformer_patches["input_block_patch_after_skip"]
-            for p in patch:
-                h = p(h, transformer_options)
-
-    transformer_options["block"] = ("middle", 0)
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'before', transformer_options)
-
-    h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context,
-                               num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-    h = apply_control(h, control, 'middle')
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'after', transformer_options)
-
-    for id, module in enumerate(self.output_blocks):
-        transformer_options["block"] = ("output", id)
-        hsp = hs.pop()
-        hsp = apply_control(hsp, control, 'output')
-
-        if "output_block_patch" in transformer_patches:
-            patch = transformer_patches["output_block_patch"]
-            for p in patch:
-                h, hsp = p(h, hsp, transformer_options)
-
-        h = torch.cat([h, hsp], dim=1)
-        del hsp
-        if len(hs) > 0:
-            output_shape = hs[-1].shape
-        else:
-            output_shape = None
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'before', transformer_options)
-
-        h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape,
-                                   time_context=time_context, num_video_frames=num_video_frames,
-                                   image_only_indicator=image_only_indicator)
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'after', transformer_options)
-
-    transformer_options["block"] = ("last", 0)
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'before', transformer_options)
-
-    if self.predict_codebook_ids:
-        h = self.id_predictor(h)
-    else:
-        h = self.out(h)
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'after', transformer_options)
-
-    return h.type(x.dtype)
-
-
-def patch_all():
-    UNetModel.forward = forge_unet_forward