mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
1. Add an option to let users run the UNet in fp8/gguf while keeping LoRA in fp16.
2. FP16 LoRAs never need patching; other LoRAs are re-patched only when their weights change.
3. FP8 UNet + fp16 LoRA now works in Forge (and, for the moment, mostly only in Forge). This also fixes some "LoRA too subtle" problems.
4. Significantly speed up all gguf models (in Async mode) by using an independent thread (CUDA stream) so that compute and dequantization run at the same time, even when low-bit weights are already on the GPU (see the sketch below).
5. Treat "online LoRA" as a module similar to ControlLoRA, so that it moves to the GPU together with the model when sampling, achieving a significant speedup and perfect low-VRAM management at the same time.
386 lines
14 KiB
Python
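Point 4 of the commit message overlaps GGUF dequantization with compute by running them on separate CUDA streams. Below is a minimal sketch of that pattern, assuming a hypothetical dequant_gguf() helper and quantized-layer attributes (weight_q, weight_fp16); Forge's real implementation lives in its backend operations, not in this file.

import torch

def forward_with_async_dequant(layers, x, dequant_gguf):
    dequant_stream = torch.cuda.Stream()
    # First layer has nothing computing yet, so dequantize it synchronously.
    layers[0].weight_fp16 = dequant_gguf(layers[0].weight_q)
    for i, layer in enumerate(layers):
        if i + 1 < len(layers):
            # Dequantize the *next* layer on the side stream while the
            # current layer computes on the default stream.
            with torch.cuda.stream(dequant_stream):
                layers[i + 1].weight_fp16 = dequant_gguf(layers[i + 1].weight_q)
        x = layer(x)
        # Do not start the next layer until its prefetched weights are ready.
        torch.cuda.current_stream().wait_stream(dequant_stream)
    return x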
# Started from some code from early ComfyUI and then 80% rewritten,
# mainly for supporting different special control methods in Forge
# Copyright Forge 2024


import torch
import math
import collections

from backend import memory_management
from backend.sampling.condition import Condition, compile_conditions, compile_weighted_conditions
from backend.operations import cleanup_cache
from backend.args import dynamic_args
from backend import utils


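# get_area_and_mult resolves one conditioning entry into the latent crop it
# applies to, a per-pixel blend weight, and the model-specific conditioning
# tensors. `area` is (height, width, y_offset, x_offset) in latent pixels.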
def get_area_and_mult(conds, x_in, timestep_in):
    area = (x_in.shape[2], x_in.shape[3], 0, 0)
    strength = 1.0

    if 'timestep_start' in conds:
        timestep_start = conds['timestep_start']
        if timestep_in[0] > timestep_start:
            return None
    if 'timestep_end' in conds:
        timestep_end = conds['timestep_end']
        if timestep_in[0] < timestep_end:
            return None
    if 'area' in conds:
        area = conds['area']
    if 'strength' in conds:
        strength = conds['strength']

    input_x = x_in[:, :, area[2]:area[0] + area[2], area[3]:area[1] + area[3]]

    if 'mask' in conds:
        # Weight the conditioning by an explicit mask (with optional mask strength).
        mask_strength = 1.0
        if "mask_strength" in conds:
            mask_strength = conds["mask_strength"]
        mask = conds['mask']
        assert mask.shape[1] == x_in.shape[2]
        assert mask.shape[2] == x_in.shape[3]
        mask = mask[:, area[2]:area[0] + area[2], area[3]:area[1] + area[3]] * mask_strength
        mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1)
    else:
        mask = torch.ones_like(input_x)
    mult = mask * strength

    if 'mask' not in conds:
        # Feather an 8-pixel ramp on every edge of the area that does not touch
        # the latent border, so overlapping areas blend smoothly.
        rr = 8
        if area[2] != 0:
            for t in range(rr):
                mult[:, :, t:1 + t, :] *= ((1.0 / rr) * (t + 1))
        if (area[0] + area[2]) < x_in.shape[2]:
            for t in range(rr):
                mult[:, :, area[0] - 1 - t:area[0] - t, :] *= ((1.0 / rr) * (t + 1))
        if area[3] != 0:
            for t in range(rr):
                mult[:, :, :, t:1 + t] *= ((1.0 / rr) * (t + 1))
        if (area[1] + area[3]) < x_in.shape[3]:
            for t in range(rr):
                mult[:, :, :, area[1] - 1 - t:area[1] - t] *= ((1.0 / rr) * (t + 1))

    conditioning = {}
    model_conds = conds["model_conds"]
    for c in model_conds:
        conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)

    control = conds.get('control', None)

    patches = None
    cond_obj = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches'])
    return cond_obj(input_x, mult, conditioning, area, control, patches)


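# The next two helpers decide whether two scheduled chunks may be fused into a
# single UNet call: shapes, control objects, patches, and every conditioning
# entry must be compatible.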
def cond_equal_size(c1, c2):
    if c1 is c2:
        return True
    if c1.keys() != c2.keys():
        return False
    for k in c1:
        if not c1[k].can_concat(c2[k]):
            return False
    return True


def can_concat_cond(c1, c2):
    if c1.input_x.shape != c2.input_x.shape:
        return False

    def objects_concatable(obj1, obj2):
        if (obj1 is None) != (obj2 is None):
            return False
        if obj1 is not None:
            if obj1 is not obj2:
                return False
        return True

    if not objects_concatable(c1.control, c2.control):
        return False

    if not objects_concatable(c1.patches, c2.patches):
        return False

    return cond_equal_size(c1.conditioning, c2.conditioning)


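# cond_cat merges the per-chunk conditioning dicts key by key, letting the
# first Condition object of each key concatenate the remaining ones.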
def cond_cat(c_list):
    # Gather values per key across all chunks.
    temp = {}
    for x in c_list:
        for k in x:
            cur = temp.get(k, [])
            cur.append(x[k])
            temp[k] = cur

    out = {}
    for k in temp:
        conds = temp[k]
        out[k] = conds[0].concat(conds[1:])

    return out


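# compute_cond_mark expands the per-chunk cond/uncond flags (0 = cond,
# 1 = uncond) into a per-sample tensor aligned with the batched input;
# compute_cond_indices returns the matching batch indices for each side.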
def compute_cond_mark(cond_or_uncond, sigmas):
    cond_or_uncond_size = int(sigmas.shape[0])

    cond_mark = []
    for cx in cond_or_uncond:
        cond_mark += [cx] * cond_or_uncond_size

    cond_mark = torch.Tensor(cond_mark).to(sigmas)
    return cond_mark


def compute_cond_indices(cond_or_uncond, sigmas):
    cl = int(sigmas.shape[0])

    cond_indices = []
    uncond_indices = []
    for i, cx in enumerate(cond_or_uncond):
        if cx == 0:
            cond_indices += list(range(i * cl, (i + 1) * cl))
        else:
            uncond_indices += list(range(i * cl, (i + 1) * cl))

    return cond_indices, uncond_indices


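# calc_cond_uncond_batch runs every active cond and uncond chunk through the
# model, fusing compatible chunks into shared batches sized to fit free VRAM,
# then averages the weighted predictions back into full-latent cond and
# uncond outputs.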
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options):
    out_cond = torch.zeros_like(x_in)
    out_count = torch.ones_like(x_in) * 1e-37

    out_uncond = torch.zeros_like(x_in)
    out_uncond_count = torch.ones_like(x_in) * 1e-37

    COND = 0
    UNCOND = 1

    to_run = []
    for x in cond:
        p = get_area_and_mult(x, x_in, timestep)
        if p is None:
            continue

        to_run += [(p, COND)]
    if uncond is not None:
        for x in uncond:
            p = get_area_and_mult(x, x_in, timestep)
            if p is None:
                continue

            to_run += [(p, UNCOND)]

    while len(to_run) > 0:
        first = to_run[0]
        first_shape = first[0][0].shape  # shape of input_x
        to_batch_temp = []
        for x in range(len(to_run)):
            if can_concat_cond(to_run[x][0], first[0]):
                to_batch_temp += [x]

        to_batch_temp.reverse()
        to_batch = to_batch_temp[:1]

        # Grow the batch as large as free VRAM allows, trying ever smaller
        # fractions of the compatible chunks until one fits.
        free_memory = memory_management.get_free_memory(x_in.device)
        for i in range(1, len(to_batch_temp) + 1):
            batch_amount = to_batch_temp[:len(to_batch_temp) // i]
            input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
            if model.memory_required(input_shape) < free_memory:
                to_batch = batch_amount
                break

        input_x = []
        mult = []
        c = []
        cond_or_uncond = []
        area = []
        control = None
        patches = None
        for x in to_batch:
            o = to_run.pop(x)
            p = o[0]
            input_x.append(p.input_x)
            mult.append(p.mult)
            c.append(p.conditioning)
            area.append(p.area)
            cond_or_uncond.append(o[1])
            control = p.control
            patches = p.patches

        batch_chunks = len(cond_or_uncond)
        input_x = torch.cat(input_x)
        c = cond_cat(c)
        timestep_ = torch.cat([timestep] * batch_chunks)

        transformer_options = {}
        if 'transformer_options' in model_options:
            transformer_options = model_options['transformer_options'].copy()

        if patches is not None:
            if "patches" in transformer_options:
                cur_patches = transformer_options["patches"].copy()
                for p in patches:
                    if p in cur_patches:
                        cur_patches[p] = cur_patches[p] + patches[p]
                    else:
                        cur_patches[p] = patches[p]
                transformer_options["patches"] = cur_patches
            else:
                transformer_options["patches"] = patches

        transformer_options["cond_or_uncond"] = cond_or_uncond[:]
        transformer_options["sigmas"] = timestep

        transformer_options["cond_mark"] = compute_cond_mark(cond_or_uncond=cond_or_uncond, sigmas=timestep)
        transformer_options["cond_indices"], transformer_options["uncond_indices"] = compute_cond_indices(cond_or_uncond=cond_or_uncond, sigmas=timestep)

        c['transformer_options'] = transformer_options

        if control is not None:
            p = control
            while p is not None:
                p.transformer_options = transformer_options
                p = p.previous_controlnet
            control_cond = c.copy()  # get_control may change items in this dict, so we need to copy it
            c['control'] = control.get_control(input_x, timestep_, control_cond, len(cond_or_uncond))
            c['control_model'] = control

        if 'model_function_wrapper' in model_options:
            output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks)
        else:
            output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks)
        del input_x

        # Accumulate each chunk's weighted prediction back into its area.
        for o in range(batch_chunks):
            if cond_or_uncond[o] == COND:
                out_cond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
                out_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[o]
            else:
                out_uncond[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o]
                out_uncond_count[:, :, area[o][2]:area[o][0] + area[o][2], area[o][3]:area[o][1] + area[o][3]] += mult[o]
        del mult

    out_cond /= out_count
    del out_count
    out_uncond /= out_uncond_count
    del out_uncond_count
    return out_cond, out_uncond


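# sampling_function_inner applies classifier-free guidance:
#     cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
# optionally scaled by the summed edit strength, with pre- and post-CFG hooks.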
def sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None, return_full=False):
    edit_strength = sum((item['strength'] if 'strength' in item else 1) for item in cond)

    if math.isclose(cond_scale, 1.0) and not model_options.get("disable_cfg1_optimization", False):
        # With cond_scale == 1.0 the uncond prediction cancels out of the CFG
        # formula, so skip computing it entirely.
        uncond_ = None
    else:
        uncond_ = uncond

    for fn in model_options.get("sampler_pre_cfg_function", []):
        model, cond, uncond_, x, timestep, model_options = fn(model, cond, uncond_, x, timestep, model_options)

    cond_pred, uncond_pred = calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)

    if "sampler_cfg_function" in model_options:
        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        cfg_result = x - model_options["sampler_cfg_function"](args)
    elif not math.isclose(edit_strength, 1.0):
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale * edit_strength
    else:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    if return_full:
        return cfg_result, cond_pred, uncond_pred

    return cfg_result


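# sampling_function is the denoiser entry point wired into the sampler: it
# compiles the text conditions, attaches image concat conditions and linked
# ControlNets, applies conditioning modifiers, then delegates to
# sampling_function_inner.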
def sampling_function(self, denoiser_params, cond_scale, cond_composition):
    unet_patcher = self.inner_model.inner_model.forge_objects.unet
    model = unet_patcher.model
    control = unet_patcher.controlnet_linked_list
    extra_concat_condition = unet_patcher.extra_concat_condition
    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = compile_conditions(denoiser_params.text_uncond)
    cond = compile_weighted_conditions(denoiser_params.text_cond, cond_composition)
    model_options = unet_patcher.model_options
    seed = self.p.seeds[0]

    if extra_concat_condition is not None:
        image_cond_in = extra_concat_condition
    else:
        image_cond_in = denoiser_params.image_cond

    if isinstance(image_cond_in, torch.Tensor):
        if image_cond_in.shape[0] == x.shape[0] \
                and image_cond_in.shape[2] == x.shape[2] \
                and image_cond_in.shape[3] == x.shape[3]:
            for i in range(len(uncond)):
                uncond[i]['model_conds']['c_concat'] = Condition(image_cond_in)
            for i in range(len(cond)):
                cond[i]['model_conds']['c_concat'] = Condition(image_cond_in)

    if control is not None:
        for h in cond + uncond:
            h['control'] = control

    for modifier in model_options.get('conditioning_modifiers', []):
        model, x, timestep, uncond, cond, cond_scale, model_options, seed = modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    denoised, cond_pred, uncond_pred = sampling_function_inner(model, x, timestep, uncond, cond, cond_scale, model_options, seed, return_full=True)
    return denoised, cond_pred, uncond_pred


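# sampling_prepare budgets VRAM before sampling: it estimates UNet peak
# memory, reserves extra room for ControlNets and online LoRA patches, then
# loads everything onto the GPU in one pass.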
def sampling_prepare(unet, x):
    B, C, H, W = x.shape

    memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)

    # Batch is doubled because cond and uncond chunks go through the UNet together.
    unet_inference_memory = memory_estimation_function([B * 2, C, H, W])
    additional_inference_memory = unet.extra_preserved_memory_during_sampling
    additional_model_patchers = unet.extra_model_patchers_during_sampling

    if unet.controlnet_linked_list is not None:
        additional_inference_memory += unet.controlnet_linked_list.inference_memory_requirements(unet.model_dtype())
        additional_model_patchers += unet.controlnet_linked_list.get_models()

    if dynamic_args.get('online_lora', False):
        lora_memory = utils.nested_compute_size(unet.lora_loader.patches)
        additional_inference_memory += lora_memory

    memory_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    if dynamic_args.get('online_lora', False):
        utils.nested_move_to_device(unet.lora_loader.patches, device=unet.current_device)

    unet.lora_loader.patches = {}

    real_model = unet.model

    percent_to_timestep_function = lambda p: real_model.predictor.percent_to_sigma(p)

    for cnet in unet.list_controlnets():
        cnet.pre_run(real_model, percent_to_timestep_function)

    return


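# sampling_cleanup releases per-sampling state: ControlNet buffers and the
# operations cache.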
def sampling_cleanup(unet):
    for cnet in unet.list_controlnets():
        cnet.cleanup()
    cleanup_cache()
    return
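
For orientation, a hypothetical driver showing how these entry points bracket a sampling run; run_sampler and the sd_model wiring stand in for the actual sampler loop, which lives elsewhere in Forge:

unet = sd_model.forge_objects.unet          # UNet patcher, as used in sampling_function

sampling_prepare(unet, x=initial_latent)    # budget VRAM and load models to GPU
latent = run_sampler(initial_latent)        # invokes sampling_function every step
sampling_cleanup(unet)                      # release ControlNet state and caches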