This commit is contained in:
lllyasviel
2024-01-31 09:46:24 -08:00
parent 8a0dcd09c9
commit 071be046d2
15 changed files with 153 additions and 561 deletions

View File

@@ -825,6 +825,7 @@ class UNetModel(nn.Module):
transformer_options["original_shape"] = list(x.shape)
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
block_modifiers = transformer_options.get("block_modifiers", [])
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
@@ -844,8 +845,16 @@ class UNetModel(nn.Module):
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
@@ -858,9 +867,15 @@ class UNetModel(nn.Module):
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
@@ -878,9 +893,26 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
transformer_options["block"] = ("last", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
if self.predict_codebook_ids:
return self.id_predictor(h)
h = self.id_predictor(h)
else:
return self.out(h)
h = self.out(h)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
return h.type(x.dtype)

View File

@@ -225,19 +225,13 @@ class CheckpointFunction(torch.autograd.Function):
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
# Consistent with Kohya to reduce differences between model training and inference.
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:

View File

@@ -11,6 +11,9 @@ import ldm_patched.controlnet.cldm
import ldm_patched.t2ia.adapter
compute_controlnet_weighting = None
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
@@ -114,6 +117,10 @@ class ControlBase:
x = x.to(output_dtype)
out[key].append(x)
if compute_controlnet_weighting is not None:
out = compute_controlnet_weighting(out, self)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]

View File

@@ -1,3 +1,4 @@
import time
import psutil
from enum import Enum
from ldm_patched.modules.args_parser import args
@@ -42,8 +43,6 @@ if args.directml is not None:
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
@@ -128,6 +127,9 @@ try:
except:
OOM_EXCEPTION = Exception
if directml_enabled:
OOM_EXCEPTION = Exception
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
@@ -376,6 +378,8 @@ def free_memory(memory_required, device, keep_loaded=[]):
def load_models_gpu(models, memory_required=0):
global vram_state
execution_start_time = time.perf_counter()
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
@@ -390,7 +394,7 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
print(f"To load target model {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
@@ -398,9 +402,14 @@ def load_models_gpu(models, memory_required=0):
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'Moving model(s) skipped. Freeing memory has taken {moving_time:.2f} seconds')
return
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {}
for loaded_model in models_to_load:
@@ -433,6 +442,11 @@ def load_models_gpu(models, memory_required=0):
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'Moving model(s) has taken {moving_time:.2f} seconds')
return

View File

@@ -1,5 +1,24 @@
import torch
import ldm_patched.modules.model_management
import contextlib
@contextlib.contextmanager
def use_patched_ops(operations):
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
try:
for op_name in op_names:
setattr(torch.nn, op_name, getattr(operations, op_name))
yield
finally:
for op_name in op_names:
setattr(torch.nn, op_name, backups[op_name])
return
def cast_bias_weight(s, input):
bias = None

View File

@@ -126,6 +126,29 @@ def cond_cat(c_list):
return out
def compute_cond_mark(cond_or_uncond, sigmas):
cond_or_uncond_size = int(sigmas.shape[0])
cond_mark = []
for cx in cond_or_uncond:
cond_mark += [cx] * cond_or_uncond_size
cond_mark = torch.Tensor(cond_mark).to(sigmas)
return cond_mark
def compute_cond_indices(cond_or_uncond, sigmas):
cl = int(sigmas.shape[0])
cond_indices = []
uncond_indices = []
for i, cx in enumerate(cond_or_uncond):
if cx == 0:
cond_indices += list(range(i * cl, (i + 1) * cl))
else:
uncond_indices += list(range(i * cl, (i + 1) * cl))
return cond_indices, uncond_indices
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
@@ -193,9 +216,6 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
@@ -214,8 +234,18 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
c['transformer_options'] = transformer_options
if control is not None:
p = control
while p is not None:
p.transformer_options = transformer_options
p = p.previous_controlnet
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:

View File

@@ -8,6 +8,7 @@ import zipfile
from . import model_management
import ldm_patched.modules.clip_model
import json
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@@ -74,11 +75,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
with open(textmodel_json_config) as f:
config = json.load(f)
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
self.transformer = model_class(config, dtype, device, ldm_patched.modules.ops.manual_cast)
self.num_layers = self.transformer.num_layers
with ldm_patched.modules.ops.use_patched_ops(ldm_patched.modules.ops.manual_cast):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.to(torch.float32)
self.max_length = max_length
if freeze:
@@ -169,16 +176,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
output_hidden_states=self.layer == "hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs[1]
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
if outputs[2] is not None:
pooled_output = outputs[2].float()
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None

View File

@@ -33,19 +33,13 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.getcwd(), "output")
temp_directory = os.path.join(os.getcwd(), "temp")
input_directory = os.path.join(os.getcwd(), "input")
user_directory = os.path.join(os.getcwd(), "user")
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")
user_directory = os.path.join(base_path, "user")
filename_list_cache = {}
if not os.path.exists(input_directory):
try:
pass # os.makedirs(input_directory)
except:
print("Failed to create input directory")
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir

View File

@@ -66,14 +66,14 @@ def apply_controlnet_advanced(
return m
def compute_controlnet_weighting(
control,
positive_advanced_weighting,
negative_advanced_weighting,
advanced_frame_weighting,
advanced_sigma_weighting,
transformer_options
):
def compute_controlnet_weighting(control, cnet):
positive_advanced_weighting = cnet.positive_advanced_weighting
negative_advanced_weighting = cnet.negative_advanced_weighting
advanced_frame_weighting = cnet.advanced_frame_weighting
advanced_sigma_weighting = cnet.advanced_sigma_weighting
transformer_options = cnet.transformer_options
if positive_advanced_weighting is None and negative_advanced_weighting is None \
and advanced_frame_weighting is None and advanced_sigma_weighting is None:
return control

View File

@@ -24,22 +24,9 @@ def initialize_forge():
torch.zeros((1, 1)).to(device, torch.float32)
model_management.soft_empty_cache()
import modules_forge.patch_clip
modules_forge.patch_clip.patch_all_clip()
import modules_forge.patch_precision
modules_forge.patch_precision.patch_all_precision()
import modules_forge.patch_basic
modules_forge.patch_basic.patch_all_basics()
import modules_forge.unet_patcher
modules_forge.unet_patcher.patch_all()
if model_management.directml_enabled:
model_management.lowvram_available = True
model_management.OOM_EXCEPTION = Exception
from modules_forge import supported_preprocessor
from modules_forge import supported_controlnet

View File

@@ -2,23 +2,7 @@ import time
import torch
import contextlib
from ldm_patched.modules import model_management
@contextlib.contextmanager
def use_patched_ops(operations):
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
try:
for op_name in op_names:
setattr(torch.nn, op_name, getattr(operations, op_name))
yield
finally:
for op_name in op_names:
setattr(torch.nn, op_name, backups[op_name])
return
from ldm_patched.modules.ops import use_patched_ops
@contextlib.contextmanager

View File

@@ -1,201 +1,6 @@
import torch
import os
import time
import safetensors
import ldm_patched.modules.samplers
from ldm_patched.modules.controlnet import ControlBase
from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
from ldm_patched.modules import model_management
from modules_forge.controlnet import compute_controlnet_weighting
from modules_forge.forge_util import compute_cond_mark, compute_cond_indices
def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
out = {'input': [], 'middle': [], 'output': []}
if control_input is not None:
for i in range(len(control_input)):
key = 'input'
x = control_input[i]
if x is not None:
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].insert(0, x)
if control_output is not None:
for i in range(len(control_output)):
if i == (len(control_output) - 1):
key = 'middle'
index = 0
else:
key = 'output'
index = i
x = control_output[i]
if x is not None:
if self.global_average_pooling:
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
x *= self.strength
if x.dtype != output_dtype:
x = x.to(output_dtype)
out[key].append(x)
out = compute_controlnet_weighting(
out,
positive_advanced_weighting=self.positive_advanced_weighting,
negative_advanced_weighting=self.negative_advanced_weighting,
advanced_frame_weighting=self.advanced_frame_weighting,
advanced_sigma_weighting=self.advanced_sigma_weighting,
transformer_options=self.transformer_options
)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]
for i in range(len(control_prev[x])):
prev_val = control_prev[x][i]
if i >= len(o):
o.append(prev_val)
elif prev_val is not None:
if o[i] is None:
o[i] = prev_val
else:
if o[i].shape[0] < prev_val.shape[0]:
o[i] = prev_val + o[i]
else:
o[i] += prev_val
return out
def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
out_uncond = torch.zeros_like(x_in)
out_uncond_count = torch.ones_like(x_in) * 1e-37
COND = 0
UNCOND = 1
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
if uncond is not None:
for x in uncond:
p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
to_run += [(p, UNCOND)]
while len(to_run) > 0:
first = to_run[0]
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
free_memory = model_management.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break
input_x = []
mult = []
c = []
cond_or_uncond = []
area = []
control = None
patches = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
input_x.append(p.input_x)
mult.append(p.mult)
c.append(p.conditioning)
area.append(p.area)
cond_or_uncond.append(o[1])
control = p.control
patches = p.patches
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
if patches is not None:
if "patches" in transformer_options:
cur_patches = transformer_options["patches"].copy()
for p in patches:
if p in cur_patches:
cur_patches[p] = cur_patches[p] + patches[p]
else:
cur_patches[p] = patches[p]
else:
transformer_options["patches"] = patches
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
c['transformer_options'] = transformer_options
if control is not None:
p = control
while p is not None:
p.transformer_options = transformer_options
p = p.previous_controlnet
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:
output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
if cond_or_uncond[o] == COND:
out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
else:
out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
del mult
out_cond /= out_count
del out_count
out_uncond /= out_uncond_count
del out_uncond_count
return out_cond, out_uncond
def patched_load_models_gpu(*args, **kwargs):
execution_start_time = time.perf_counter()
y = model_management.load_models_gpu_origin(*args, **kwargs)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'Moving model(s) has taken {moving_time:.2f} seconds')
return y
def build_loaded(module, loader_name):
@@ -233,14 +38,10 @@ def build_loaded(module, loader_name):
def patch_all_basics():
if not hasattr(model_management, 'load_models_gpu_origin'):
model_management.load_models_gpu_origin = model_management.load_models_gpu
model_management.load_models_gpu = patched_load_models_gpu
ControlBase.control_merge = patched_control_merge
ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch
import ldm_patched.modules.controlnet
import modules_forge.controlnet
ldm_patched.modules.controlnet.compute_controlnet_weighting = modules_forge.controlnet.compute_controlnet_weighting
build_loaded(safetensors.torch, 'load_file')
build_loaded(torch, 'load')
return

View File

@@ -1,112 +0,0 @@
# Consistent with Kohya/A1111 to reduce differences between model training and inference.
import os
import torch
import ldm_patched.controlnet.cldm
import ldm_patched.k_diffusion.sampling
import ldm_patched.ldm.modules.attention
import ldm_patched.ldm.modules.diffusionmodules.model
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
import ldm_patched.modules.args_parser
import ldm_patched.modules.model_base
import ldm_patched.modules.model_management
import ldm_patched.modules.model_patcher
import ldm_patched.modules.samplers
import ldm_patched.modules.sd
import ldm_patched.modules.sd1_clip
import ldm_patched.modules.clip_vision
import ldm_patched.modules.ops as ops
from modules_forge.ops import use_patched_ops
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None,
textmodel_json_config=None, dtype=None, special_tokens=None,
layer_norm_hidden_state=True, **kwargs):
torch.nn.Module.__init__(self)
assert layer in self.LAYERS
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)),
"sd1_clip_config.json")
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
with use_patched_ops(ops.manual_cast):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.to(torch.float32)
if freeze:
self.freeze()
self.max_length = max_length
self.layer = layer
self.layer_idx = None
self.special_tokens = special_tokens
self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
self.enable_attention_masks = False
self.layer_norm_hidden_state = layer_norm_hidden_state
if layer == "hidden":
assert layer_idx is not None
assert abs(layer_idx) < self.num_layers
self.clip_layer(layer_idx)
self.layer_default = (self.layer, self.layer_idx)
def patched_SDClipModel_forward(self, tokens):
backup_embeds = self.transformer.get_input_embeddings()
device = backup_embeds.weight.device
tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
tokens = torch.LongTensor(tokens).to(device)
attention_mask = None
if self.enable_attention_masks:
attention_mask = torch.zeros_like(tokens)
max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
for x in range(attention_mask.shape[0]):
for y in range(attention_mask.shape[1]):
attention_mask[x, y] = 1
if tokens[x, y] == max_token:
break
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
output_hidden_states=self.layer == "hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None
if self.text_projection is not None and pooled_output is not None:
pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
return z.float(), pooled_output
def patch_all_clip():
ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
return

View File

@@ -1,60 +0,0 @@
# Consistent with Kohya to reduce differences between model training and inference.
import torch
import math
import einops
import numpy as np
import ldm_patched.ldm.modules.diffusionmodules.openaimodel
import ldm_patched.modules.model_sampling
import ldm_patched.modules.sd1_clip
from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule
def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
# Consistent with Kohya to reduce differences between model training and inference.
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = einops.repeat(timesteps, 'b -> b d', d=dim)
return embedding
def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
# Consistent with Kohya to reduce differences between model training and inference.
if given_betas is not None:
betas = given_betas
else:
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
self.set_sigmas(sigmas)
return
def patch_all_precision():
ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding
ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule
return

View File

@@ -1,7 +1,5 @@
import copy
import torch
from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed, apply_control
from ldm_patched.modules.model_patcher import ModelPatcher
@@ -76,111 +74,3 @@ class UnetPatcher(ModelPatcher):
for transformer_index in range(16):
self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
return
def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
transformer_options["original_shape"] = list(x.shape)
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
block_modifiers = transformer_options.get("block_modifiers", [])
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
time_context = kwargs.get("time_context", None)
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context,
num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)
hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context,
num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
hsp = apply_control(hsp, control, 'output')
if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
for p in patch:
h, hsp = p(h, hsp, transformer_options)
h = torch.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0:
output_shape = hs[-1].shape
else:
output_shape = None
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape,
time_context=time_context, num_video_frames=num_video_frames,
image_only_indicator=image_only_indicator)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
transformer_options["block"] = ("last", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
if self.predict_codebook_ids:
h = self.id_predictor(h)
else:
h = self.out(h)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
return h.type(x.dtype)
def patch_all():
UNetModel.forward = forge_unet_forward