This commit is contained in:
lllyasviel
2024-01-31 13:16:41 -08:00
parent 86bd258a8e
commit 67c38f9294
11 changed files with 167 additions and 68 deletions

View File

@@ -186,6 +186,26 @@ class ConditioningSetAreaPercentage:
c.append(n)
return (c, )
class ConditioningSetAreaStrength:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", ),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, conditioning, strength):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['strength'] = strength
c.append(n)
return (c, )
class ConditioningSetMask:
@classmethod
def INPUT_TYPES(s):
@@ -1756,6 +1776,7 @@ NODE_CLASS_MAPPINGS = {
"ConditioningConcat": ConditioningConcat,
"ConditioningSetArea": ConditioningSetArea,
"ConditioningSetAreaPercentage": ConditioningSetAreaPercentage,
"ConditioningSetAreaStrength": ConditioningSetAreaStrength,
"ConditioningSetMask": ConditioningSetMask,
"KSamplerAdvanced": KSamplerAdvanced,
"SetLatentNoiseMask": SetLatentNoiseMask,

View File

@@ -825,6 +825,7 @@ class UNetModel(nn.Module):
transformer_options["original_shape"] = list(x.shape)
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})
block_modifiers = transformer_options.get("block_modifiers", [])
num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
@@ -844,8 +845,16 @@ class UNetModel(nn.Module):
h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'input')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
@@ -858,9 +867,15 @@ class UNetModel(nn.Module):
h = p(h, transformer_options)
transformer_options["block"] = ("middle", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle')
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
@@ -878,9 +893,26 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = h.type(x.dtype)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
transformer_options["block"] = ("last", 0)
for block_modifier in block_modifiers:
h = block_modifier(h, 'before', transformer_options)
if self.predict_codebook_ids:
return self.id_predictor(h)
h = self.id_predictor(h)
else:
return self.out(h)
h = self.out(h)
for block_modifier in block_modifiers:
h = block_modifier(h, 'after', transformer_options)
return h.type(x.dtype)

View File

@@ -225,19 +225,13 @@ class CheckpointFunction(torch.autograd.Function):
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
# Consistent with Kohya to reduce differences between model training and inference.
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half
)
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:

View File

@@ -11,6 +11,9 @@ import ldm_patched.controlnet.cldm
import ldm_patched.t2ia.adapter
compute_controlnet_weighting = None
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
#print(current_batch_size, target_batch_size)
@@ -114,6 +117,10 @@ class ControlBase:
x = x.to(output_dtype)
out[key].append(x)
if compute_controlnet_weighting is not None:
out = compute_controlnet_weighting(out, self)
if control_prev is not None:
for x in ['input', 'middle', 'output']:
o = out[x]

View File

@@ -1,3 +1,4 @@
import time
import psutil
from enum import Enum
from ldm_patched.modules.args_parser import args
@@ -42,8 +43,6 @@ if args.directml is not None:
else:
directml_device = torch_directml.device(device_index)
print("Using directml with device:", torch_directml.device_name(device_index))
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex
@@ -128,6 +127,9 @@ try:
except:
OOM_EXCEPTION = Exception
if directml_enabled:
OOM_EXCEPTION = Exception
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
@@ -376,6 +378,8 @@ def free_memory(memory_required, device, keep_loaded=[]):
def load_models_gpu(models, memory_required=0):
global vram_state
execution_start_time = time.perf_counter()
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
@@ -390,7 +394,7 @@ def load_models_gpu(models, memory_required=0):
models_already_loaded.append(loaded_model)
else:
if hasattr(x, "model"):
print(f"Requested to load {x.model.__class__.__name__}")
print(f"To load target model {x.model.__class__.__name__}")
models_to_load.append(loaded_model)
if len(models_to_load) == 0:
@@ -398,9 +402,14 @@ def load_models_gpu(models, memory_required=0):
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'Moving model(s) skipped. Freeing memory has taken {moving_time:.2f} seconds')
return
print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
print(f"Begin to load {len(models_to_load)} model{'s' if len(models_to_load) > 1 else ''}")
total_memory_required = {}
for loaded_model in models_to_load:
@@ -433,6 +442,11 @@ def load_models_gpu(models, memory_required=0):
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
current_loaded_models.insert(0, loaded_model)
moving_time = time.perf_counter() - execution_start_time
if moving_time > 0.1:
print(f'Moving model(s) has taken {moving_time:.2f} seconds')
return

View File

@@ -1,5 +1,24 @@
import torch
import ldm_patched.modules.model_management
import contextlib
@contextlib.contextmanager
def use_patched_ops(operations):
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
try:
for op_name in op_names:
setattr(torch.nn, op_name, getattr(operations, op_name))
yield
finally:
for op_name in op_names:
setattr(torch.nn, op_name, backups[op_name])
return
def cast_bias_weight(s, input):
bias = None

View File

@@ -126,6 +126,29 @@ def cond_cat(c_list):
return out
def compute_cond_mark(cond_or_uncond, sigmas):
cond_or_uncond_size = int(sigmas.shape[0])
cond_mark = []
for cx in cond_or_uncond:
cond_mark += [cx] * cond_or_uncond_size
cond_mark = torch.Tensor(cond_mark).to(sigmas)
return cond_mark
def compute_cond_indices(cond_or_uncond, sigmas):
cl = int(sigmas.shape[0])
cond_indices = []
uncond_indices = []
for i, cx in enumerate(cond_or_uncond):
if cx == 0:
cond_indices += list(range(i * cl, (i + 1) * cl))
else:
uncond_indices += list(range(i * cl, (i + 1) * cl))
return cond_indices, uncond_indices
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in) * 1e-37
@@ -193,9 +216,6 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
c = cond_cat(c)
timestep_ = torch.cat([timestep] * batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
transformer_options = {}
if 'transformer_options' in model_options:
transformer_options = model_options['transformer_options'].copy()
@@ -214,8 +234,18 @@ def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
transformer_options["sigmas"] = timestep
transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)
c['transformer_options'] = transformer_options
if control is not None:
p = control
while p is not None:
p.transformer_options = transformer_options
p = p.previous_controlnet
c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond))
if 'model_function_wrapper' in model_options:
output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
else:

View File

@@ -462,7 +462,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
model.load_model_weights(sd, "model.diffusion_model.")
if output_vae:
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae = VAE(sd=vae_sd)

View File

@@ -8,6 +8,7 @@ import zipfile
from . import model_management
import ldm_patched.modules.clip_model
import json
from transformers import CLIPTextModel, CLIPTextConfig, modeling_utils
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
@@ -74,11 +75,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
with open(textmodel_json_config) as f:
config = json.load(f)
config = CLIPTextConfig.from_json_file(textmodel_json_config)
self.num_layers = config.num_hidden_layers
self.transformer = model_class(config, dtype, device, ldm_patched.modules.ops.manual_cast)
self.num_layers = self.transformer.num_layers
with ldm_patched.modules.ops.use_patched_ops(ldm_patched.modules.ops.manual_cast):
with modeling_utils.no_init_weights():
self.transformer = CLIPTextModel(config)
if dtype is not None:
self.transformer.to(dtype)
self.transformer.text_model.embeddings.to(torch.float32)
self.max_length = max_length
if freeze:
@@ -169,16 +176,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if tokens[x, y] == max_token:
break
outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
outputs = self.transformer(input_ids=tokens, attention_mask=attention_mask,
output_hidden_states=self.layer == "hidden")
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":
z = outputs[0]
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs[1]
z = outputs.hidden_states[self.layer_idx]
if self.layer_norm_hidden_state:
z = self.transformer.text_model.final_layer_norm(z)
if outputs[2] is not None:
pooled_output = outputs[2].float()
if hasattr(outputs, "pooler_output"):
pooled_output = outputs.pooler_output.float()
else:
pooled_output = None

View File

@@ -21,6 +21,7 @@ class BASE:
noise_aug_config = None
sampling_settings = {}
latent_format = latent_formats.LatentFormat
vae_key_prefix = ["first_stage_model."]
manual_cast_dtype = None

View File

@@ -5,47 +5,16 @@ supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
folder_names_and_paths = {}
base_path = os.getcwd()
models_dir = os.path.join(base_path, "models")
folder_names_and_paths["checkpoints"] = ([os.path.join(models_dir, "checkpoints")], supported_pt_extensions)
folder_names_and_paths["configs"] = ([os.path.join(models_dir, "configs")], [".yaml"])
folder_names_and_paths["loras"] = ([os.path.join(models_dir, "loras")], supported_pt_extensions)
folder_names_and_paths["vae"] = ([os.path.join(models_dir, "vae")], supported_pt_extensions)
folder_names_and_paths["clip"] = ([os.path.join(models_dir, "clip")], supported_pt_extensions)
folder_names_and_paths["unet"] = ([os.path.join(models_dir, "unet")], supported_pt_extensions)
folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision")], supported_pt_extensions)
folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions)
folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions)
folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"])
folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions)
folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions)
folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions)
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], [])
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")], supported_pt_extensions)
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
output_directory = os.path.join(os.getcwd(), "output")
temp_directory = os.path.join(os.getcwd(), "temp")
input_directory = os.path.join(os.getcwd(), "input")
user_directory = os.path.join(os.getcwd(), "user")
# Will be assigned by modules.paths
base_path = None
models_dir = None
output_directory = None
temp_directory = None
input_directory = None
user_directory = None
filename_list_cache = {}
if not os.path.exists(input_directory):
try:
pass # os.makedirs(input_directory)
except:
print("Failed to create input directory")
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir