diff --git a/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py b/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py
index 4b695f76..ffd168af 100644
--- a/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py
+++ b/ldm_patched/ldm/modules/diffusionmodules/openaimodel.py
@@ -825,6 +825,7 @@ class UNetModel(nn.Module):
         transformer_options["original_shape"] = list(x.shape)
         transformer_options["transformer_index"] = 0
         transformer_patches = transformer_options.get("patches", {})
+        block_modifiers = transformer_options.get("block_modifiers", [])
 
         num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
         image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
@@ -844,8 +845,16 @@ class UNetModel(nn.Module):
         h = x
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'before', transformer_options)
+
             h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
             h = apply_control(h, control, 'input')
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'after', transformer_options)
+
             if "input_block_patch" in transformer_patches:
                 patch = transformer_patches["input_block_patch"]
                 for p in patch:
@@ -858,9 +867,15 @@ class UNetModel(nn.Module):
                     h = p(h, transformer_options)
 
         transformer_options["block"] = ("middle", 0)
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'before', transformer_options)
+
         h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = apply_control(h, control, 'middle')
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'after', transformer_options)
 
         for id, module in enumerate(self.output_blocks):
             transformer_options["block"] = ("output", id)
@@ -878,9 +893,26 @@ class UNetModel(nn.Module):
                 output_shape = hs[-1].shape
             else:
                 output_shape = None
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'before', transformer_options)
+
             h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-        h = h.type(x.dtype)
+
+            for block_modifier in block_modifiers:
+                h = block_modifier(h, 'after', transformer_options)
+
+        transformer_options["block"] = ("last", 0)
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'before', transformer_options)
+
         if self.predict_codebook_ids:
-            return self.id_predictor(h)
+            h = self.id_predictor(h)
         else:
-            return self.out(h)
+            h = self.out(h)
+
+        for block_modifier in block_modifiers:
+            h = block_modifier(h, 'after', transformer_options)
+
+        return h.type(x.dtype)
diff --git a/ldm_patched/ldm/modules/diffusionmodules/util.py b/ldm_patched/ldm/modules/diffusionmodules/util.py
index e261e06a..eeef837e 100644
--- a/ldm_patched/ldm/modules/diffusionmodules/util.py
+++ b/ldm_patched/ldm/modules/diffusionmodules/util.py
@@ -225,19 +225,13 @@ class CheckpointFunction(torch.autograd.Function):
 
 
 def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
-    """
-    Create sinusoidal timestep embeddings.
-    :param timesteps: a 1-D Tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param dim: the dimension of the output.
-    :param max_period: controls the minimum frequency of the embeddings.
-    :return: an [N x dim] Tensor of positional embeddings.
-    """
+    # Consistent with Kohya to reduce differences between model training and inference.
+
     if not repeat_only:
         half = dim // 2
         freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
-        )
+            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        ).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
diff --git a/ldm_patched/modules/controlnet.py b/ldm_patched/modules/controlnet.py
index 7e11497f..67d9c9ee 100644
--- a/ldm_patched/modules/controlnet.py
+++ b/ldm_patched/modules/controlnet.py
@@ -11,6 +11,9 @@
 import ldm_patched.controlnet.cldm
 import ldm_patched.t2ia.adapter
 
+
+compute_controlnet_weighting = None
+
+
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
     #print(current_batch_size, target_batch_size)
@@ -114,6 +117,10 @@ class ControlBase:
                     x = x.to(output_dtype)
                 out[key].append(x)
+
+        if compute_controlnet_weighting is not None:
+            out = compute_controlnet_weighting(out, self)
+
         if control_prev is not None:
             for x in ['input', 'middle', 'output']:
                 o = out[x]
diff --git a/ldm_patched/modules/model_management.py b/ldm_patched/modules/model_management.py
index 6f88579d..c9d9f52f 100644
--- a/ldm_patched/modules/model_management.py
+++ b/ldm_patched/modules/model_management.py
@@ -1,3 +1,4 @@
+import time
 import psutil
 from enum import Enum
 from ldm_patched.modules.args_parser import args
@@ -42,8 +43,6 @@
     else:
         directml_device = torch_directml.device(device_index)
         print("Using directml with device:", torch_directml.device_name(device_index))
-    # torch_directml.disable_tiled_resources(True)
-    lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
 
 try:
     import intel_extension_for_pytorch as ipex
@@ -128,6 +127,9 @@
 except:
     OOM_EXCEPTION = Exception
 
+if directml_enabled:
+    OOM_EXCEPTION = Exception
+
 XFORMERS_VERSION = ""
 XFORMERS_ENABLED_VAE = True
 if args.disable_xformers:
@@ -376,6 +378,8 @@ def free_memory(memory_required, device, keep_loaded=[]):
 
 def load_models_gpu(models, memory_required=0):
     global vram_state
 
+    execution_start_time = time.perf_counter()
+
     inference_memory = minimum_inference_memory()
     extra_mem = max(inference_memory, memory_required)
@@ -390,7 +394,7 @@
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
+                print(f"To load target model {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
 
     if len(models_to_load) == 0:
@@ -398,9 +402,14 @@
         for d in devs:
             if d != torch.device("cpu"):
                 free_memory(extra_mem, d, models_already_loaded)
+
+        moving_time = time.perf_counter() - execution_start_time
+        if moving_time > 0.1:
+            print(f'Moving model(s) skipped. Freeing memory has taken {moving_time:.2f} seconds')
+
         return
 
-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")
 
     total_memory_required = {}
     for loaded_model in models_to_load:
@@ -433,6 +442,11 @@
         cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
         current_loaded_models.insert(0, loaded_model)
+
+    moving_time = time.perf_counter() - execution_start_time
+    if moving_time > 0.1:
+        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
+
     return
diff --git a/ldm_patched/modules/ops.py b/ldm_patched/modules/ops.py
index 2d7fa377..c9926fd2 100644
--- a/ldm_patched/modules/ops.py
+++ b/ldm_patched/modules/ops.py
@@ -1,5 +1,24 @@
 import torch
 import ldm_patched.modules.model_management
+import contextlib
+
+
+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+
+        yield
+
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+        return
+
 
 def cast_bias_weight(s, input):
     bias = None
diff --git a/ldm_patched/modules/samplers.py b/ldm_patched/modules/samplers.py
index 1f69d2b1..d6f36b21 100644
--- a/ldm_patched/modules/samplers.py
+++ b/ldm_patched/modules/samplers.py
@@ -126,6 +126,29 @@ def cond_cat(c_list):
     return out
 
+
+def compute_cond_mark(cond_or_uncond, sigmas):
+    cond_or_uncond_size = int(sigmas.shape[0])
+
+    cond_mark = []
+    for cx in cond_or_uncond:
+        cond_mark += [cx] * cond_or_uncond_size
+
+    cond_mark = torch.Tensor(cond_mark).to(sigmas)
+    return cond_mark
+
+
+def compute_cond_indices(cond_or_uncond, sigmas):
+    cl = int(sigmas.shape[0])
+
+    cond_indices = []
+    uncond_indices = []
+    for i, cx in enumerate(cond_or_uncond):
+        if cx == 0:
+            cond_indices += list(range(i * cl, (i + 1) * cl))
+        else:
+            uncond_indices += list(range(i * cl, (i + 1) * cl))
+
+    return cond_indices, uncond_indices
+
 def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
     out_cond = torch.zeros_like(x_in)
     out_count = torch.ones_like(x_in) * 1e-37
@@ -193,9 +216,6 @@
         c = cond_cat(c)
         timestep_ = torch.cat([timestep] * batch_chunks)
 
-        if control is not None:
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
         transformer_options = {}
         if 'transformer_options' in model_options:
             transformer_options = model_options['transformer_options'].copy()
@@ -214,8 +234,18 @@
         transformer_options["cond_or_uncond"] = cond_or_uncond[:]
         transformer_options["sigmas"] = timestep
 
+        transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
+        transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
+
         c['transformer_options'] = transformer_options
 
+        if control is not None:
+            p = control
+            while p is not None:
+                p.transformer_options = transformer_options
+                p = p.previous_controlnet
+            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
+
         if 'model_function_wrapper' in model_options:
             output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
         else:
diff --git a/ldm_patched/modules/sd1_clip.py b/ldm_patched/modules/sd1_clip.py
index 3727fb48..a1cdec2e 100644
--- a/ldm_patched/modules/sd1_clip.py
+++ b/ldm_patched/modules/sd1_clip.py
@@ -8,6 +8,7 @@
 from . import model_management
 import ldm_patched.modules.clip_model
 import json
+from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
 
 def gen_empty_tokens(special_tokens, length):
     start_token = special_tokens.get("start", None)
@@ -74,11 +75,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
 
-        with open(textmodel_json_config) as f:
-            config = json.load(f)
+        config = CLIPTextConfig.from_json_file(textmodel_json_config)
+        self.num_layers = config.num_hidden_layers
 
-        self.transformer = model_class(config, dtype, device, ldm_patched.modules.ops.manual_cast)
-        self.num_layers = self.transformer.num_layers
+        with ldm_patched.modules.ops.use_patched_ops(ldm_patched.modules.ops.manual_cast):
+            with modeling_utils.no_init_weights():
+                self.transformer = CLIPTextModel(config)
+
+        if dtype is not None:
+            self.transformer.to(dtype)
+
+        self.transformer.text_model.embeddings.to(torch.float32)
 
         self.max_length = max_length
         if freeze:
@@ -169,16 +176,21 @@
                         if tokens[x, y] == max_token:
                             break
 
-        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
+                                   output_hidden_states=self.layer == "hidden")
         self.transformer.set_input_embeddings(backup_embeds)
 
         if self.layer == "last":
-            z = outputs[0]
+            z = outputs.last_hidden_state
+        elif self.layer == "pooled":
+            z = outputs.pooler_output[:, None, :]
         else:
-            z = outputs[1]
+            z = outputs.hidden_states[self.layer_idx]
+            if self.layer_norm_hidden_state:
+                z = self.transformer.text_model.final_layer_norm(z)
 
-        if outputs[2] is not None:
-            pooled_output = outputs[2].float()
+        if hasattr(outputs, "pooler_output"):
+            pooled_output = outputs.pooler_output.float()
         else:
             pooled_output = None
diff --git a/ldm_patched/utils/path_utils.py b/ldm_patched/utils/path_utils.py
index 6cae149b..a2643f43 100644
--- a/ldm_patched/utils/path_utils.py
+++ b/ldm_patched/utils/path_utils.py
@@ -33,19 +33,13 @@
 folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
 folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
 
-output_directory = os.path.join(os.getcwd(), "output")
-temp_directory = os.path.join(os.getcwd(), "temp")
-input_directory = os.path.join(os.getcwd(), "input")
-user_directory = os.path.join(os.getcwd(), "user")
+output_directory = os.path.join(base_path, "output")
+temp_directory = os.path.join(base_path, "temp")
+input_directory = os.path.join(base_path, "input")
+user_directory = os.path.join(base_path, "user")
 
 filename_list_cache = {}
 
-if not os.path.exists(input_directory):
-    try:
-        pass # os.makedirs(input_directory)
-    except:
-        print("Failed to create input directory")
-
 def set_output_directory(output_dir):
     global output_directory
     output_directory = output_dir
diff --git a/modules_forge/controlnet.py b/modules_forge/controlnet.py
index 24371de9..8ad88311 100644
--- a/modules_forge/controlnet.py
+++ b/modules_forge/controlnet.py
@@ -66,14 +66,14 @@ def apply_controlnet_advanced(
     return m
 
 
-def compute_controlnet_weighting(
-        control,
-        positive_advanced_weighting,
-        negative_advanced_weighting,
-        advanced_frame_weighting,
-        advanced_sigma_weighting,
-        transformer_options
-):
+def compute_controlnet_weighting(control, cnet):
+
+    positive_advanced_weighting = cnet.positive_advanced_weighting
+    negative_advanced_weighting = cnet.negative_advanced_weighting
+    advanced_frame_weighting = cnet.advanced_frame_weighting
+    advanced_sigma_weighting = cnet.advanced_sigma_weighting
+    transformer_options = cnet.transformer_options
+
     if positive_advanced_weighting is None and negative_advanced_weighting is None \
             and advanced_frame_weighting is None and advanced_sigma_weighting is None:
         return control
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 415f33af..0cd86531 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -24,22 +24,9 @@ def initialize_forge():
     torch.zeros((1, 1)).to(device, torch.float32)
     model_management.soft_empty_cache()
 
-    import modules_forge.patch_clip
-    modules_forge.patch_clip.patch_all_clip()
-
-    import modules_forge.patch_precision
-    modules_forge.patch_precision.patch_all_precision()
-
     import modules_forge.patch_basic
     modules_forge.patch_basic.patch_all_basics()
 
-    import modules_forge.unet_patcher
-    modules_forge.unet_patcher.patch_all()
-
-    if model_management.directml_enabled:
-        model_management.lowvram_available = True
-        model_management.OOM_EXCEPTION = Exception
-
     from modules_forge import supported_preprocessor
     from modules_forge import supported_controlnet
diff --git a/modules_forge/ops.py b/modules_forge/ops.py
index 2a182e36..a0a0d171 100644
--- a/modules_forge/ops.py
+++ b/modules_forge/ops.py
@@ -2,23 +2,7 @@ import time
 import torch
 import contextlib
 from ldm_patched.modules import model_management
-
-
-@contextlib.contextmanager
-def use_patched_ops(operations):
-    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
-    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
-
-    try:
-        for op_name in op_names:
-            setattr(torch.nn, op_name, getattr(operations, op_name))
-
-        yield
-
-    finally:
-        for op_name in op_names:
-            setattr(torch.nn, op_name, backups[op_name])
-        return
+from ldm_patched.modules.ops import use_patched_ops
 
 
 @contextlib.contextmanager
diff --git a/modules_forge/patch_basic.py b/modules_forge/patch_basic.py
index 64067ec3..eab64573 100644
--- a/modules_forge/patch_basic.py
+++ b/modules_forge/patch_basic.py
@@ -1,201 +1,6 @@
 import torch
 import os
-import time
 import safetensors
-import ldm_patched.modules.samplers
-
-from ldm_patched.modules.controlnet import ControlBase
-from ldm_patched.modules.samplers import get_area_and_mult, can_concat_cond, cond_cat
-from ldm_patched.modules import model_management
-from modules_forge.controlnet import compute_controlnet_weighting
-from modules_forge.forge_util import compute_cond_mark, compute_cond_indices
-
-
-def patched_control_merge(self, control_input, control_output, control_prev, output_dtype):
-    out = {'input': [], 'middle': [], 'output': []}
-
-    if control_input is not None:
-        for i in range(len(control_input)):
-            key = 'input'
-            x = control_input[i]
-            if x is not None:
-                x *= self.strength
-                if x.dtype != output_dtype:
-                    x = x.to(output_dtype)
-            out[key].insert(0, x)
-
-    if control_output is not None:
-        for i in range(len(control_output)):
-            if i == (len(control_output) - 1):
-                key = 'middle'
-                index = 0
-            else:
-                key = 'output'
-                index = i
-            x = control_output[i]
-            if x is not None:
-                if self.global_average_pooling:
-                    x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
-
-                x *= self.strength
-                if x.dtype != output_dtype:
-                    x = x.to(output_dtype)
-
-            out[key].append(x)
-
-    out = compute_controlnet_weighting(
-        out,
-        positive_advanced_weighting=self.positive_advanced_weighting,
-        negative_advanced_weighting=self.negative_advanced_weighting,
-        advanced_frame_weighting=self.advanced_frame_weighting,
-        advanced_sigma_weighting=self.advanced_sigma_weighting,
-        transformer_options=self.transformer_options
-    )
-
-    if control_prev is not None:
-        for x in ['input', 'middle', 'output']:
-            o = out[x]
-            for i in range(len(control_prev[x])):
-                prev_val = control_prev[x][i]
-                if i >= len(o):
-                    o.append(prev_val)
-                elif prev_val is not None:
-                    if o[i] is None:
-                        o[i] = prev_val
-                    else:
-                        if o[i].shape[0] < prev_val.shape[0]:
-                            o[i] = prev_val + o[i]
-                        else:
-                            o[i] += prev_val
-    return out
-
-
-def patched_calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
-    out_cond = torch.zeros_like(x_in)
-    out_count = torch.ones_like(x_in) * 1e-37
-
-    out_uncond = torch.zeros_like(x_in)
-    out_uncond_count = torch.ones_like(x_in) * 1e-37
-
-    COND = 0
-    UNCOND = 1
-
-    to_run = []
-    for x in cond:
-        p = get_area_and_mult(x, x_in, timestep)
-        if p is None:
-            continue
-
-        to_run += [(p, COND)]
-    if uncond is not None:
-        for x in uncond:
-            p = get_area_and_mult(x, x_in, timestep)
-            if p is None:
-                continue
-
-            to_run += [(p, UNCOND)]
-
-    while len(to_run) > 0:
-        first = to_run[0]
-        first_shape = first[0][0].shape
-        to_batch_temp = []
-        for x in range(len(to_run)):
-            if can_concat_cond(to_run[x][0], first[0]):
-                to_batch_temp += [x]
-
-        to_batch_temp.reverse()
-        to_batch = to_batch_temp[:1]
-
-        free_memory = model_management.get_free_memory(x_in.device)
-        for i in range(1, len(to_batch_temp) + 1):
-            batch_amount = to_batch_temp[:len(to_batch_temp)//i]
-            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
-            if model.memory_required(input_shape) < free_memory:
-                to_batch = batch_amount
-                break
-
-        input_x = []
-        mult = []
-        c = []
-        cond_or_uncond = []
-        area = []
-        control = None
-        patches = None
-        for x in to_batch:
-            o = to_run.pop(x)
-            p = o[0]
-            input_x.append(p.input_x)
-            mult.append(p.mult)
-            c.append(p.conditioning)
-            area.append(p.area)
-            cond_or_uncond.append(o[1])
-            control = p.control
-            patches = p.patches
-
-        batch_chunks = len(cond_or_uncond)
-        input_x = torch.cat(input_x)
-        c = cond_cat(c)
-        timestep_ = torch.cat([timestep] * batch_chunks)
-
-        transformer_options = {}
-        if 'transformer_options' in model_options:
-            transformer_options = model_options['transformer_options'].copy()
-
-        if patches is not None:
-            if "patches" in transformer_options:
-                cur_patches = transformer_options["patches"].copy()
-                for p in patches:
-                    if p in cur_patches:
-                        cur_patches[p] = cur_patches[p] + patches[p]
-                    else:
-                        cur_patches[p] = patches[p]
-            else:
-                transformer_options["patches"] = patches
-
-        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
-        transformer_options["sigmas"] = timestep
-
-        transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
-        transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
-
-        c['transformer_options'] = transformer_options
-
-        if control is not None:
-            p = control
-            while p is not None:
-                p.transformer_options = transformer_options
-                p = p.previous_controlnet
-            c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
-
-        if 'model_function_wrapper' in model_options:
-            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
-        else:
-            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
-        del input_x
-
-        for o in range(batch_chunks):
-            if cond_or_uncond[o] == COND:
-                out_cond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-            else:
-                out_uncond[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
-                out_uncond_count[:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o]
-        del mult
-
-    out_cond /= out_count
-    del out_count
-    out_uncond /= out_uncond_count
-    del out_uncond_count
-    return out_cond, out_uncond
-
-
-def patched_load_models_gpu(*args, **kwargs):
-    execution_start_time = time.perf_counter()
-    y = model_management.load_models_gpu_origin(*args, **kwargs)
-    moving_time = time.perf_counter() - execution_start_time
-    if moving_time > 0.1:
-        print(f'Moving model(s) has taken {moving_time:.2f} seconds')
-    return y
 
 
 def build_loaded(module, loader_name):
@@ -233,14 +38,10 @@
 
 
 def patch_all_basics():
-    if not hasattr(model_management, 'load_models_gpu_origin'):
-        model_management.load_models_gpu_origin = model_management.load_models_gpu
-
-    model_management.load_models_gpu = patched_load_models_gpu
-
-    ControlBase.control_merge = patched_control_merge
-    ldm_patched.modules.samplers.calc_cond_uncond_batch = patched_calc_cond_uncond_batch
+    import ldm_patched.modules.controlnet
+    import modules_forge.controlnet
+    ldm_patched.modules.controlnet.compute_controlnet_weighting = modules_forge.controlnet.compute_controlnet_weighting
     build_loaded(safetensors.torch, 'load_file')
     build_loaded(torch, 'load')
     return
diff --git a/modules_forge/patch_clip.py b/modules_forge/patch_clip.py
deleted file mode 100644
index 3cdbdac7..00000000
--- a/modules_forge/patch_clip.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Consistent with Kohya/A1111 to reduce differences between model training and inference.
-
-import os
-import torch
-import ldm_patched.controlnet.cldm
-import ldm_patched.k_diffusion.sampling
-import ldm_patched.ldm.modules.attention
-import ldm_patched.ldm.modules.diffusionmodules.model
-import ldm_patched.ldm.modules.diffusionmodules.openaimodel
-import ldm_patched.ldm.modules.diffusionmodules.openaimodel
-import ldm_patched.modules.args_parser
-import ldm_patched.modules.model_base
-import ldm_patched.modules.model_management
-import ldm_patched.modules.model_patcher
-import ldm_patched.modules.samplers
-import ldm_patched.modules.sd
-import ldm_patched.modules.sd1_clip
-import ldm_patched.modules.clip_vision
-import ldm_patched.modules.ops as ops
-
-from modules_forge.ops import use_patched_ops
-from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
-
-
-def patched_SDClipModel__init__(self, max_length=77, freeze=True, layer="last", layer_idx=None,
-                                textmodel_json_config=None, dtype=None, special_tokens=None,
-                                layer_norm_hidden_state=True, **kwargs):
-    torch.nn.Module.__init__(self)
-    assert layer in self.LAYERS
-
-    if special_tokens is None:
-        special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
-
-    if textmodel_json_config is None:
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(ldm_patched.modules.sd1_clip.__file__)),
-                                             "sd1_clip_config.json")
-
-    config = CLIPTextConfig.from_json_file(textmodel_json_config)
-    self.num_layers = config.num_hidden_layers
-
-    with use_patched_ops(ops.manual_cast):
-        with modeling_utils.no_init_weights():
-            self.transformer = CLIPTextModel(config)
-
-    if dtype is not None:
-        self.transformer.to(dtype)
-
-    self.transformer.text_model.embeddings.to(torch.float32)
-
-    if freeze:
-        self.freeze()
-
-    self.max_length = max_length
-    self.layer = layer
-    self.layer_idx = None
-    self.special_tokens = special_tokens
-    self.text_projection = torch.nn.Parameter(torch.eye(self.transformer.get_input_embeddings().weight.shape[1]))
-    self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055))
-    self.enable_attention_masks = False
-
-    self.layer_norm_hidden_state = layer_norm_hidden_state
-    if layer == "hidden":
-        assert layer_idx is not None
-        assert abs(layer_idx) < self.num_layers
-        self.clip_layer(layer_idx)
-    self.layer_default = (self.layer, self.layer_idx)
-
-
-def patched_SDClipModel_forward(self, tokens):
-    backup_embeds = self.transformer.get_input_embeddings()
-    device = backup_embeds.weight.device
-    tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
-    tokens = torch.LongTensor(tokens).to(device)
-
-    attention_mask = None
-    if self.enable_attention_masks:
-        attention_mask = torch.zeros_like(tokens)
-        max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-        for x in range(attention_mask.shape[0]):
-            for y in range(attention_mask.shape[1]):
-                attention_mask[x, y] = 1
-                if tokens[x, y] == max_token:
-                    break
-
-    outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
-                               output_hidden_states=self.layer == "hidden")
-    self.transformer.set_input_embeddings(backup_embeds)
-
-    if self.layer == "last":
-        z = outputs.last_hidden_state
-    elif self.layer == "pooled":
-        z = outputs.pooler_output[:, None, :]
-    else:
-        z = outputs.hidden_states[self.layer_idx]
-        if self.layer_norm_hidden_state:
-            z = self.transformer.text_model.final_layer_norm(z)
-
-    if hasattr(outputs, "pooler_output"):
-        pooled_output = outputs.pooler_output.float()
-    else:
-        pooled_output = None
-
-    if self.text_projection is not None and pooled_output is not None:
-        pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
-
-    return z.float(), pooled_output
-
-
-def patch_all_clip():
-    ldm_patched.modules.sd1_clip.SDClipModel.__init__ = patched_SDClipModel__init__
-    ldm_patched.modules.sd1_clip.SDClipModel.forward = patched_SDClipModel_forward
-    return
diff --git a/modules_forge/patch_precision.py b/modules_forge/patch_precision.py
deleted file mode 100644
index 83569bdd..00000000
--- a/modules_forge/patch_precision.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Consistent with Kohya to reduce differences between model training and inference.
-
-import torch
-import math
-import einops
-import numpy as np
-
-import ldm_patched.ldm.modules.diffusionmodules.openaimodel
-import ldm_patched.modules.model_sampling
-import ldm_patched.modules.sd1_clip
-
-from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule
-
-
-def patched_timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
-    # Consistent with Kohya to reduce differences between model training and inference.
-
-    if not repeat_only:
-        half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
-        args = timesteps[:, None].float() * freqs[None]
-        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
-        if dim % 2:
-            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
-    else:
-        embedding = einops.repeat(timesteps, 'b -> b d', d=dim)
-    return embedding
-
-
-def patched_register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
-                              linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
-    # Consistent with Kohya to reduce differences between model training and inference.
-
-    if given_betas is not None:
-        betas = given_betas
-    else:
-        betas = make_beta_schedule(
-            beta_schedule,
-            timesteps,
-            linear_start=linear_start,
-            linear_end=linear_end,
-            cosine_s=cosine_s)
-
-    alphas = 1. - betas
-    alphas_cumprod = np.cumprod(alphas, axis=0)
-
-    timesteps, = betas.shape
-    self.num_timesteps = int(timesteps)
-    self.linear_start = linear_start
-    self.linear_end = linear_end
-
-    sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
-    self.set_sigmas(sigmas)
-    return
-
-
-def patch_all_precision():
-    ldm_patched.ldm.modules.diffusionmodules.openaimodel.timestep_embedding = patched_timestep_embedding
-    ldm_patched.modules.model_sampling.ModelSamplingDiscrete._register_schedule = patched_register_schedule
-    return
diff --git a/modules_forge/unet_patcher.py b/modules_forge/unet_patcher.py
index 4f429bae..1d4b3122 100644
--- a/modules_forge/unet_patcher.py
+++ b/modules_forge/unet_patcher.py
@@ -1,7 +1,5 @@
 import copy
 
-import torch
-from ldm_patched.ldm.modules.diffusionmodules.openaimodel import UNetModel, timestep_embedding, forward_timestep_embed, apply_control
 from ldm_patched.modules.model_patcher import ModelPatcher
@@ -76,111 +74,3 @@ class UnetPatcher(ModelPatcher):
         for transformer_index in range(16):
             self.set_model_patch_replace(patch, target, block_name, number, transformer_index)
         return
-
-
-def forge_unet_forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options=None, **kwargs):
-    if transformer_options is None:
-        transformer_options = {}
-
-    transformer_options["original_shape"] = list(x.shape)
-    transformer_options["transformer_index"] = 0
-    transformer_patches = transformer_options.get("patches", {})
-    block_modifiers = transformer_options.get("block_modifiers", [])
-
-    num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
-    image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
-    time_context = kwargs.get("time_context", None)
-
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape[0] == x.shape[0]
-        emb = emb + self.label_emb(y)
-
-    h = x
-    for id, module in enumerate(self.input_blocks):
-        transformer_options["block"] = ("input", id)
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'before', transformer_options)
-
-        h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context,
-                                   num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-        h = apply_control(h, control, 'input')
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'after', transformer_options)
-
-        if "input_block_patch" in transformer_patches:
-            patch = transformer_patches["input_block_patch"]
-            for p in patch:
-                h = p(h, transformer_options)
-
-        hs.append(h)
-        if "input_block_patch_after_skip" in transformer_patches:
-            patch = transformer_patches["input_block_patch_after_skip"]
-            for p in patch:
-                h = p(h, transformer_options)
-
-    transformer_options["block"] = ("middle", 0)
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'before', transformer_options)
-
-    h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context,
-                               num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
-    h = apply_control(h, control, 'middle')
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'after', transformer_options)
-
-    for id, module in enumerate(self.output_blocks):
-        transformer_options["block"] = ("output", id)
-        hsp = hs.pop()
-        hsp = apply_control(hsp, control, 'output')
-
-        if "output_block_patch" in transformer_patches:
-            patch = transformer_patches["output_block_patch"]
-            for p in patch:
-                h, hsp = p(h, hsp, transformer_options)
-
-        h = torch.cat([h, hsp], dim=1)
-        del hsp
-        if len(hs) > 0:
-            output_shape = hs[-1].shape
-        else:
-            output_shape = None
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'before', transformer_options)
-
-        h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape,
-                                   time_context=time_context, num_video_frames=num_video_frames,
-                                   image_only_indicator=image_only_indicator)
-
-        for block_modifier in block_modifiers:
-            h = block_modifier(h, 'after', transformer_options)
-
-    transformer_options["block"] = ("last", 0)
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'before', transformer_options)
-
-    if self.predict_codebook_ids:
-        h = self.id_predictor(h)
-    else:
-        h = self.out(h)
-
-    for block_modifier in block_modifiers:
-        h = block_modifier(h, 'after', transformer_options)
-
-    return h.type(x.dtype)
-
-
-def patch_all():
-    UNetModel.forward = forge_unet_forward