mirror of https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-01-26 19:09:45 +00:00
Update webui.py
i Update initialization.py initialization initialization Update initialization.py i i Update sd_samplers_common.py Update sd_hijack.py i Update sd_models.py Update sd_models.py Update forge_loader.py Update sd_models.py i Update sd_model.py i Update sd_models.py Create sd_model.py i i Update sd_models.py i Update sd_models.py Update sd_models.py i i Update sd_samplers_common.py i Update sd_models.py Update sd_models.py Update sd_samplers_common.py Update sd_models.py Update sd_models.py Update sd_models.py Update sd_models.py Update sd_samplers_common.py i Update shared_options.py Update prompt_parser.py Update sd_hijack_unet.py i Update sd_models.py Update sd_models.py Update sd_models.py Update devices.py i Update sd_vae.py Update sd_models.py Update processing.py Update ui_settings.py Update sd_models_xl.py i i Update sd_samplers_kdiffusion.py Update sd_samplers_timesteps.py Update ui_settings.py Update cmd_args.py Update cmd_args.py Update initialization.py Update shared_options.py Update initialization.py Update shared_options.py i Update cmd_args.py Update initialization.py Update initialization.py Update initialization.py Update cmd_args.py Update cmd_args.py Update sd_hijack.py
@@ -179,26 +179,11 @@ def manual_cast(target_dtype):
 
 
 def autocast(disable=False):
-    if disable:
-        return contextlib.nullcontext()
-
-    if fp8 and device==cpu:
-        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
-
-    if fp8 and dtype_inference == torch.float32:
-        return manual_cast(dtype)
-
-    if dtype == torch.float32 or dtype_inference == torch.float32:
-        return contextlib.nullcontext()
-
-    if has_xpu() or has_mps() or cuda_no_autocast():
-        return manual_cast(dtype)
-
-    return torch.autocast("cuda")
+    return contextlib.nullcontext()
 
 
 def without_autocast(disable=False):
-    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+    return contextlib.nullcontext()
 
 
 class NansException(Exception):
 
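Net effect of this hunk: webui's global autocast regions become inert, and dtype handling moves into the manually-cast ops of the patched backend. A minimal standalone sketch of the resulting behavior (not forge's code):

    import contextlib

    def autocast(disable=False):
        # callers now always get a null context; casting is done per-module
        # by the patched ops rather than by a global autocast region
        return contextlib.nullcontext()

    with autocast():
        pass  # model calls here run at whatever dtype their weights already carry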
@@ -12,6 +12,10 @@ def imports():
     logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
+    from modules_forge.initialization import initialize_forge
+    initialize_forge()
+    startup_timer.record("initialize forge")
+
     import torch  # noqa: F401
     startup_timer.record("import torch")
     import pytorch_lightning  # noqa: F401
@@ -16,7 +16,7 @@ from skimage import exposure
 from typing import Any
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules import devices, prompt_parser, masking, sd_samplers, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
 from modules.rng import slerp  # noqa: F401
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -627,44 +627,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
 
     for i in range(batch.shape[0]):
         sample = decode_first_stage(model, batch[i:i + 1])[0]
-
-        if check_for_nans:
-
-            try:
-                devices.test_for_nans(sample, "vae")
-            except devices.NansException as e:
-                if shared.opts.auto_vae_precision_bfloat16:
-                    autofix_dtype = torch.bfloat16
-                    autofix_dtype_text = "bfloat16"
-                    autofix_dtype_setting = "Automatically convert VAE to bfloat16"
-                    autofix_dtype_comment = ""
-                elif shared.opts.auto_vae_precision:
-                    autofix_dtype = torch.float32
-                    autofix_dtype_text = "32-bit float"
-                    autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
-                    autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
-                else:
-                    raise e
-
-                if devices.dtype_vae == autofix_dtype:
-                    raise e
-
-                errors.print_error_explanation(
-                    "A tensor with all NaNs was produced in VAE.\n"
-                    f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
-                    f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
-                )
-
-                devices.dtype_vae = autofix_dtype
-                model.first_stage_model.to(devices.dtype_vae)
-                batch = batch.to(devices.dtype_vae)
-
-                sample = decode_first_stage(model, batch[i:i + 1])[0]
-
-        if target_device is not None:
-            sample = sample.to(target_device)
-
-        samples.append(sample)
+        samples.append(sample.to(target_device))
 
     return samples
 
@@ -940,8 +903,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
             p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
 
-        with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
-            samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
+        samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
         if p.scripts is not None:
             ps = scripts.PostSampleArgs(samples_ddim)
@@ -960,9 +922,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
 
             del samples_ddim
 
-            if lowvram.is_enabled(shared.sd_model):
-                lowvram.send_everything_to_cpu()
-
             devices.torch_gc()
 
             state.nextjob()
@@ -1255,7 +1214,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             image = np.array(self.firstpass_image).astype(np.float32) / 255.0
             image = np.moveaxis(image, 2, 0)
             image = torch.from_numpy(np.expand_dims(image, axis=0))
-            image = image.to(shared.device, dtype=devices.dtype_vae)
+            image = image.to(shared.device, dtype=torch.float32)
 
             if opts.sd_vae_encode_method != 'Full':
                 self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1339,7 +1298,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                 batch_images.append(image)
 
             decoded_samples = torch.from_numpy(np.array(batch_images))
-            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
+            decoded_samples = decoded_samples.to(shared.device, dtype=torch.float32)
 
             if opts.sd_vae_encode_method != 'Full':
                 self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1444,7 +1403,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if shared.opts.hires_fix_use_firstpass_conds:
             self.calculate_hr_conds()
 
-        elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+        elif shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint():  # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
             with devices.autocast():
                 extra_networks.activate(self, self.hr_extra_network_data)
 
@@ -1631,7 +1590,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
                 raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
         image = torch.from_numpy(batch_images)
-        image = image.to(shared.device, dtype=devices.dtype_vae)
+        image = image.to(shared.device, dtype=torch.float32)
 
         if opts.sd_vae_encode_method != 'Full':
            self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -276,6 +276,12 @@ class DictWithShape(dict):
     def shape(self):
         return self["crossattn"].shape
 
+    def to(self, *args, **kwargs):
+        for k in self.keys():
+            if isinstance(self[k], torch.Tensor):
+                self[k] = self[k].to(*args, **kwargs)
+        return self
+
 
 def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
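The new DictWithShape.to() casts every tensor value in place and returns the dict, so a conditioning dictionary can be moved or cast like a plain tensor. The same pattern, standalone (field names here are only illustrative):

    import torch

    class TensorDict(dict):
        def to(self, *args, **kwargs):
            for k in self.keys():
                if isinstance(self[k], torch.Tensor):
                    self[k] = self[k].to(*args, **kwargs)
            return self

    cond = TensorDict(crossattn=torch.zeros(2, 77, 768), vector=torch.zeros(2, 1280))
    cond.to(torch.float16)             # casts every tensor value in place
    print(cond["vector"].dtype)        # torch.float16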
@@ -3,6 +3,7 @@ from collections import namedtuple
 
 import torch
 
+import ldm_patched.modules.model_management as model_management
 from modules import prompt_parser, devices, sd_hijack
 from modules.shared import opts
 
@@ -210,6 +211,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
         """
 
+        if hasattr(self.wrapped, 'patcher'):
+            model_management.load_model_gpu(self.wrapped.patcher)
+
         if opts.use_old_emphasis_implementation:
             import modules.sd_hijack_clip_old
             return modules.sd_hijack_clip_old.forward_old(self, texts)
@@ -13,10 +13,16 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, sd_hijack, patches
 from modules.timer import Timer
 import tomesd
 import numpy as np
+from modules_forge import forge_loader
+import modules_forge.ops as forge_ops
+from ldm_patched.modules.ops import manual_cast
+from ldm_patched.modules import model_management as model_management
+import ldm_patched.modules.model_patcher
+
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -366,10 +372,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         sd_model_hash = checkpoint_info.calculate_shorthash()
         timer.record("calculate hash")
 
-    if devices.fp8:
-        # prevent model to load state dict in fp8
-        model.half()
-
     if not SkipWritingToConfig.skip:
         shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
 
@@ -379,13 +381,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sdxl = hasattr(model, 'conditioner')
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
-    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in model.state_dict().keys()
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
-    if model.is_ssd:
-        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
-
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
@@ -395,65 +394,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
 
     del state_dict
 
-    if shared.cmd_opts.opt_channelslast:
-        model.to(memory_format=torch.channels_last)
-        timer.record("apply channels_last")
-
-    if shared.cmd_opts.no_half:
-        model.float()
-        model.alphas_cumprod_original = model.alphas_cumprod
-        devices.dtype_unet = torch.float32
-        timer.record("apply float()")
-    else:
-        vae = model.first_stage_model
-        depth_model = getattr(model, 'depth_model', None)
-
-        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
-        if shared.cmd_opts.no_half_vae:
-            model.first_stage_model = None
-        # with --upcast-sampling, don't convert the depth model weights to float16
-        if shared.cmd_opts.upcast_sampling and depth_model:
-            model.depth_model = None
-
-        alphas_cumprod = model.alphas_cumprod
-        model.alphas_cumprod = None
-        model.half()
-        model.alphas_cumprod = alphas_cumprod
-        model.alphas_cumprod_original = alphas_cumprod
-        model.first_stage_model = vae
-        if depth_model:
-            model.depth_model = depth_model
-
-        devices.dtype_unet = torch.float16
-        timer.record("apply half()")
-
-    for module in model.modules():
-        if hasattr(module, 'fp16_weight'):
-            del module.fp16_weight
-        if hasattr(module, 'fp16_bias'):
-            del module.fp16_bias
-
-    if check_fp8(model):
-        devices.fp8 = True
-        first_stage = model.first_stage_model
-        model.first_stage_model = None
-        for module in model.modules():
-            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
-                if shared.opts.cache_fp16_weight:
-                    module.fp16_weight = module.weight.data.clone().cpu().half()
-                    if module.bias is not None:
-                        module.fp16_bias = module.bias.data.clone().cpu().half()
-                module.to(torch.float8_e4m3fn)
-        model.first_stage_model = first_stage
-        timer.record("apply fp8")
-    else:
-        devices.fp8 = False
-
-    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
-
-    model.first_stage_model.to(devices.dtype_vae)
-    timer.record("apply dtype to VAE")
-
     # clean up cache if limit is reached
     while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
         checkpoints_loaded.popitem(last=False)
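The deleted block is webui's own fp8 path: cache an fp16 copy of each Linear/Conv2d weight on the CPU, then store the live weight as float8_e4m3fn; forge drops it because the patched backend manages precision itself. A toy reproduction of that trick, assuming a PyTorch build with float8 dtypes (2.1 or later):

    import torch

    module = torch.nn.Linear(4, 4).half()
    module.fp16_weight = module.weight.data.clone().cpu()   # CPU backup used to undo fp8
    module.weight.data = module.weight.data.to(torch.float8_e4m3fn)
    print(module.weight.dtype)  # torch.float8_e4m3fn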
@@ -614,34 +554,6 @@ def get_empty_cond(sd_model):
         return sd_model.cond_stage_model([""])
 
 
-def send_model_to_cpu(m):
-    if m.lowvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        m.to(devices.cpu)
-
-    devices.torch_gc()
-
-
-def model_target_device(m):
-    if lowvram.is_needed(m):
-        return devices.cpu
-    else:
-        return devices.device
-
-
-def send_model_to_device(m):
-    lowvram.apply(m)
-
-    if not m.lowvram:
-        m.to(shared.device)
-
-
-def send_model_to_trash(m):
-    m.to(device="meta")
-    devices.torch_gc()
-
-
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -649,7 +561,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     timer = Timer()
 
     if model_data.sd_model:
-        send_model_to_trash(model_data.sd_model)
         model_data.sd_model = None
         devices.torch_gc()
 
@@ -670,12 +581,21 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     timer.record("load config")
 
+    if hasattr(sd_config.model.params, 'network_config'):
+        sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(sd_config.model.params, 'unet_config'):
+        sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
+
+    if hasattr(sd_config.model.params, 'first_stage_config'):
+        sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
+
     print(f"Creating model from config: {checkpoint_config}")
 
     sd_model = None
     try:
         with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
             with sd_disable_initialization.InitializeOnMeta():
-                sd_model = instantiate_from_config(sd_config.model)
+                with forge_ops.use_patched_ops(manual_cast):
+                    sd_model = instantiate_from_config(sd_config.model)
 
     except Exception as e:
@@ -684,28 +604,45 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
     if sd_model is None:
         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
 
         with sd_disable_initialization.InitializeOnMeta():
-            sd_model = instantiate_from_config(sd_config.model)
+            with forge_ops.use_patched_ops(manual_cast):
+                sd_model = instantiate_from_config(sd_config.model)
 
     sd_model.used_config = checkpoint_config
 
     timer.record("create model")
 
-    if shared.cmd_opts.no_half:
-        weight_dtype_conversion = None
-    else:
-        weight_dtype_conversion = {
-            'first_stage_model': None,
-            'alphas_cumprod': None,
-            '': torch.float16,
-        }
+    state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')}
+    state_dict_for_forge = {k: v for k, v in state_dict.items()}
+    del state_dict
 
-    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
-        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge)
+    sd_model.first_stage_model = vae_patcher.first_stage_model
+    sd_model.model.diffusion_model = unet_patcher.model.diffusion_model
+    sd_model.unet_patcher = unet_patcher
+    sd_model.model.diffusion_model.patcher = unet_patcher
+    sd_model.vae_patcher = vae_patcher
+    sd_model.first_stage_model.patcher = vae_patcher
+    timer.record("create unet patcher")
+    del state_dict_for_forge
+
+    load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer)
+    del state_dict_for_a1111
+    timer.record("load weights from state dict")
 
-    send_model_to_device(sd_model)
-    timer.record("move model to device")
+    current_clip = sd_model.conditioner if hasattr(sd_model, 'conditioner') else sd_model.cond_stage_model
+    clip_load_device = model_management.text_encoder_device()
+    clip_offload_device = model_management.text_encoder_offload_device()
+    clip_dtype = model_management.text_encoder_dtype()
+
+    current_clip.to(clip_dtype)
+    clip_patcher = ldm_patched.modules.model_patcher.ModelPatcher(
+        current_clip,
+        load_device=clip_load_device,
+        offload_device=clip_offload_device
+    )
+    sd_model.clip_patcher = clip_patcher
+    current_clip.patcher = clip_patcher
+    timer.record("create clip patcher")
 
     sd_hijack.model_hijack.hijack(sd_model)
 
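Everything above hinges on the ModelPatcher abstraction from the patched ldm backend: a wrapper pairing a module with a load device and an offload device so a memory manager can shuttle it on demand. A simplified sketch of the idea (not the real class, which also tracks weight patches):

    import torch

    class SimpleModelPatcher:
        def __init__(self, model, load_device, offload_device):
            self.model = model
            self.load_device = load_device
            self.offload_device = offload_device
            self.current_device = offload_device

        def load(self):
            # called by the memory manager right before inference
            self.model.to(self.load_device)
            self.current_device = self.load_device

        def offload(self):
            # called when VRAM is needed for another model
            self.model.to(self.offload_device)
            self.current_device = self.offload_device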
@@ -752,17 +689,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
             if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
                 print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
                 model_data.loaded_sd_models.pop()
-                send_model_to_trash(loaded_model)
-                timer.record("send model to trash")
-
-    if shared.opts.sd_checkpoints_keep_in_cpu:
-        send_model_to_cpu(sd_model)
-        timer.record("send model to cpu")
 
     if already_loaded is not None:
-        send_model_to_device(already_loaded)
-        timer.record("send model to device")
-
         model_data.set_sd_model(already_loaded, already_loaded=True)
 
         if not SkipWritingToConfig.skip:
@@ -816,7 +744,6 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
 
     if sd_model is not None:
         sd_unet.apply_unet("None")
-        send_model_to_cpu(sd_model)
         sd_hijack.model_hijack.undo_hijack(sd_model)
 
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -827,7 +754,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
         if sd_model is not None:
-            send_model_to_trash(sd_model)
+            sd_model = None
 
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
@@ -857,12 +784,6 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
     return sd_model
 
 
-def unload_model_weights(sd_model=None, info=None):
-    send_model_to_cpu(sd_model or shared.sd_model)
-
-    return sd_model
-
-
 def apply_token_merging(sd_model, token_merging_ratio):
     """
     Applies speed and memory optimizations from tomesd.
@@ -8,8 +8,12 @@ import sgm.modules.diffusionmodules.discretizer
 from modules import devices, shared, prompt_parser
 from modules import torch_utils
 
+import ldm_patched.modules.model_management as model_management
+
 
 def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
+    model_management.load_model_gpu(self.clip_patcher)
+
     for embedder in self.conditioner.embedders:
         embedder.ucg_rate = 0.0
 
@@ -18,7 +22,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
     is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
     aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
 
-    devices_args = dict(device=devices.device, dtype=devices.dtype)
+    devices_args = dict(device=self.clip_patcher.current_device, dtype=model_management.text_encoder_dtype())
 
     sdxl_conds = {
         "txt": batch,
@@ -35,11 +39,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
 
 
 def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
-    sd = self.model.state_dict()
-    diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
-    if diffusion_model_input is not None:
-        if diffusion_model_input.shape[1] == 9:
-            x = torch.cat([x] + cond['c_concat'], dim=1)
+    if self.model.diffusion_model.in_channels == 9:
+        x = torch.cat([x] + cond['c_concat'], dim=1)
 
     return self.model(x, t, cond)
 
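The replaced check does the same thing more directly: an inpainting UNet takes 9 input channels (4 noisy latent, 4 masked-image latent, 1 mask), so the extra conditioning must be concatenated on the channel axis. The shape arithmetic in isolation:

    import torch

    x = torch.randn(1, 4, 64, 64)               # noisy latent
    c_concat = [torch.randn(1, 5, 64, 64)]      # mask + masked-image latent
    x_in = torch.cat([x] + c_concat, dim=1)
    print(x_in.shape)  # torch.Size([1, 9, 64, 64])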
@@ -173,6 +173,15 @@ class CFGDenoiser(torch.nn.Module):
                 uncond = pad_cond(uncond, num_repeats, empty)
             self.padded_cond_uncond = True
 
+        unet_dtype = self.inner_model.inner_model.unet_patcher.model.model_config.unet_config['dtype']
+        x_input_dtype = x_in.dtype
+
+        x_in = x_in.to(unet_dtype)
+        sigma_in = sigma_in.to(unet_dtype)
+        image_cond_in = image_cond_in.to(unet_dtype)
+        tensor = tensor.to(unet_dtype)
+        uncond = uncond.to(unet_dtype)
+
         if tensor.shape[1] == uncond.shape[1] or skip_uncond:
             if is_edit_model:
                 cond_in = catenate_conds([tensor, uncond, uncond])
@@ -211,6 +220,8 @@ class CFGDenoiser(torch.nn.Module):
             fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
             x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
 
+        x_out = x_out.to(x_input_dtype)
+
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
 
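Together, these two hunks implement a cast-in / cast-out pattern: inputs are converted to the UNet's native dtype before the forward pass, and the output is converted back to the caller's dtype afterwards. The pattern in isolation:

    import torch

    def run_at_dtype(fn, x, unet_dtype):
        x_input_dtype = x.dtype
        y = fn(x.to(unet_dtype))       # forward pass at the UNet's own precision
        return y.to(x_input_dtype)     # hand back a tensor in the caller's dtype

    out = run_at_dtype(lambda t: t * 2, torch.randn(4), torch.float16)
    print(out.dtype)  # torch.float32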
@@ -39,9 +39,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
 
     if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
         approximation = approximation_indexes.get(opts.show_progress_type, 0)
 
-        from modules import lowvram
-        if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
+        if approximation == 0:
             approximation = 1
 
     if approximation == 2:
@@ -54,8 +52,8 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
     else:
         if model is None:
             model = shared.sd_model
-        with devices.without_autocast():  # fixes an issue with unstable VAEs that are flaky even in fp32
-            x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
+        sample = model.unet_patcher.model.model_config.latent_format.process_out(sample)
+        x_sample = model.vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0
 
     return x_sample
 
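The new decode path first undoes the latent scaling (process_out), then lets the VAE patcher decode; the patcher is assumed to return channels-last images in [0, 1], which movedim(-1, 1) * 2.0 - 1.0 converts to the NCHW, [-1, 1] layout the rest of webui expects. A sketch assuming the SD1.x scale factor:

    import torch

    def decode(vae_decode, latent, scale_factor=0.18215):
        sample = latent / scale_factor            # roughly what process_out does for SD1.x
        image = vae_decode(sample)                # assumed to return NHWC in [0, 1]
        return image.movedim(-1, 1) * 2.0 - 1.0   # NCHW in [-1, 1]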
@@ -71,7 +69,6 @@ def single_sample_to_image(sample, approximation=None):
 
 
 def decode_first_stage(model, x):
-    x = x.to(devices.dtype_vae)
     approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
     return samples_to_images_tensor(x, approx_index, model)
 
 
@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
 
 from modules.shared import opts
 import modules.shared as shared
+import ldm_patched.modules.model_management
+
 
 samplers_k_diffusion = [
     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
@@ -139,6 +141,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return sigmas
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
         sigmas = self.get_sigmas(p, steps)
@@ -193,6 +201,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
         return samples
 
     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps = steps or p.steps
 
         sigmas = self.get_sigmas(p, steps)
 
@@ -34,7 +34,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
 
     def sigma_to_t(self, sigma, quantize=None):
         log_sigma = sigma.log()
-        dists = log_sigma - self.log_sigmas[:, None]
+        dists = log_sigma - self.log_sigmas.to(sigma)[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
 
 
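This one-line fix avoids a device mismatch when the sigma table lives on the CPU but sigma arrives on the GPU: .to(sigma) moves (and casts) the table to wherever sigma is before broadcasting. The lookup in miniature:

    import torch

    log_sigmas = torch.linspace(-4, 2, 10)     # sigma table, possibly on another device
    sigma = torch.tensor([0.5, 1.5])
    dists = sigma.log() - log_sigmas.to(sigma)[:, None]
    t = dists.abs().argmin(dim=0)              # nearest table index per sigma
    print(t)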
@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
 
 from modules.shared import opts
 import modules.shared as shared
+import ldm_patched.modules.model_management
+
 
 samplers_timesteps = [
     ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
@@ -95,6 +97,12 @@ class CompVisSampler(sd_samplers_common.Sampler):
         return timesteps
 
     def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
 
         timesteps = self.get_timesteps(p, steps)
@@ -139,6 +147,12 @@ class CompVisSampler(sd_samplers_common.Sampler):
         return samples
 
     def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
+        inference_memory = 0
+        unet_patcher = self.model_wrap.inner_model.unet_patcher
+        ldm_patched.modules.model_management.load_models_gpu(
+            [unet_patcher],
+            unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
+
         steps = steps or p.steps
         timesteps = self.get_timesteps(p, steps)
 
@@ -2,7 +2,7 @@ import os
 import collections
 from dataclasses import dataclass
 
-from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, sd_hijack, hashes
 
 import glob
 from copy import deepcopy
@@ -237,7 +237,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
 # don't call this from outside
 def _load_vae_dict(model, vae_dict_1):
     model.first_stage_model.load_state_dict(vae_dict_1)
-    model.first_stage_model.to(devices.dtype_vae)
 
 
 def clear_loaded_vae():
@@ -263,20 +262,12 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
     if loaded_vae_file == vae_file:
         return
 
-    if sd_model.lowvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        sd_model.to(devices.cpu)
-
     sd_hijack.model_hijack.undo_hijack(sd_model)
 
     load_vae(sd_model, vae_file, vae_source)
 
     sd_hijack.model_hijack.hijack(sd_model)
 
-    if not sd_model.lowvram:
-        sd_model.to(devices.device)
-
     script_callbacks.model_loaded_callback(sd_model)
 
     print("VAE weights loaded.")
 
modules_forge/forge_loader.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+
+from ldm_patched.modules import model_management
+from ldm_patched.modules import model_detection
+
+from ldm_patched.modules.sd import VAE
+import ldm_patched.modules.model_patcher
+import ldm_patched.modules.utils
+
+
+def load_unet_and_vae(sd):
+    parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
+    unet_dtype = model_management.unet_dtype(model_params=parameters)
+    load_device = model_management.get_torch_device()
+    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+
+    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
+    model_config.set_manual_cast(manual_cast_dtype)
+
+    if model_config is None:
+        raise RuntimeError("ERROR: Could not detect model type of")
+
+    initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
+    model = model_config.get_model(sd, "model.diffusion_model.", device=initial_load_device)
+    model.load_model_weights(sd, "model.diffusion_model.")
+
+    model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model,
+        load_device=load_device,
+        offload_device=model_management.unet_offload_device(),
+        current_device=initial_load_device)
+
+    vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
+    vae_sd = model_config.process_vae_state_dict(vae_sd)
+    vae_patcher = VAE(sd=vae_sd)
+
+    return model_patcher, vae_patcher
+
+
+class FakeObject(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        return
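load_unet_and_vae splits one checkpoint by key prefix: UNet weights live under model.diffusion_model., VAE weights under first_stage_model. (the same prefixes load_model uses to build state_dict_for_a1111). The split, reduced to plain dict comprehensions:

    def split_state_dict(sd):
        unet = {k: v for k, v in sd.items() if k.startswith("model.diffusion_model.")}
        vae = {k[len("first_stage_model."):]: v for k, v in sd.items()
               if k.startswith("first_stage_model.")}
        return unet, vae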
modules_forge/initialization.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+import argparse
+
+
+def initialize_forge():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--no-vram", action="store_true")
+    parser.add_argument("--normal-vram", action="store_true")
+    parser.add_argument("--high-vram", action="store_true")
+    parser.add_argument("--always-vram", action="store_true")
+    vram_args = vars(parser.parse_known_args()[0])
+
+    parser = argparse.ArgumentParser()
+    attn_group = parser.add_mutually_exclusive_group()
+    attn_group.add_argument("--attention-split", action="store_true")
+    attn_group.add_argument("--attention-quad", action="store_true")
+    attn_group.add_argument("--attention-pytorch", action="store_true")
+    parser.add_argument("--disable-xformers", action="store_true")
+    fpte_group = parser.add_mutually_exclusive_group()
+    fpte_group.add_argument("--clip-in-fp8-e4m3fn", action="store_true")
+    fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
+    fpte_group.add_argument("--clip-in-fp16", action="store_true")
+    fpte_group.add_argument("--clip-in-fp32", action="store_true")
+    fp_group = parser.add_mutually_exclusive_group()
+    fp_group.add_argument("--all-in-fp32", action="store_true")
+    fp_group.add_argument("--all-in-fp16", action="store_true")
+    fpunet_group = parser.add_mutually_exclusive_group()
+    fpunet_group.add_argument("--unet-in-bf16", action="store_true")
+    fpunet_group.add_argument("--unet-in-fp16", action="store_true")
+    fpunet_group.add_argument("--unet-in-fp8-e4m3fn", action="store_true")
+    fpunet_group.add_argument("--unet-in-fp8-e5m2", action="store_true")
+    fpvae_group = parser.add_mutually_exclusive_group()
+    fpvae_group.add_argument("--vae-in-fp16", action="store_true")
+    fpvae_group.add_argument("--vae-in-fp32", action="store_true")
+    fpvae_group.add_argument("--vae-in-bf16", action="store_true")
+    other_args = vars(parser.parse_known_args()[0])
+
+    from ldm_patched.modules.args_parser import args
+
+    args.always_cpu = False
+    args.always_gpu = False
+    args.always_high_vram = False
+    args.always_low_vram = False
+    args.always_no_vram = False
+    args.always_offload_from_vram = True
+    args.async_cuda_allocation = False
+    args.disable_async_cuda_allocation = True
+
+    if vram_args['no_vram']:
+        args.always_cpu = True
+
+    if vram_args['always_vram']:
+        args.always_gpu = True
+
+    if vram_args['high_vram']:
+        args.always_offload_from_vram = False
+
+    for k, v in other_args.items():
+        setattr(args, k, v)
+
+    import ldm_patched.modules.model_management as model_management
+    import torch
+
+    device = model_management.get_torch_device()
+    torch.zeros((1, 1)).to(device, torch.float32)
+    model_management.soft_empty_cache()
+    return
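The two-pass parsing above works because parse_known_args ignores flags it does not recognize, so forge can pick its own flags out of a command line that also carries webui's. A minimal demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--high-vram", action="store_true")
    known, unknown = parser.parse_known_args(["--high-vram", "--port", "7860"])
    print(known.high_vram)  # True
    print(unknown)          # ['--port', '7860']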
modules_forge/ops.py (new file, 19 lines)
@@ -0,0 +1,19 @@
+import torch
+import contextlib
+
+
+@contextlib.contextmanager
+def use_patched_ops(operations):
+    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
+    backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
+
+    try:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, getattr(operations, op_name))
+
+        yield
+
+    finally:
+        for op_name in op_names:
+            setattr(torch.nn, op_name, backups[op_name])
+        return
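use_patched_ops temporarily swaps the constructors in torch.nn, so any layer built inside the with-block comes from the patched operations namespace, and the finally-block guarantees restoration even on error. A usage sketch with a stand-in operations namespace (LoggingLinear is illustrative only; use_patched_ops is the context manager defined above):

    import types
    import torch

    class LoggingLinear(torch.nn.Linear):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            print("patched Linear constructed")

    ops = types.SimpleNamespace(
        Linear=LoggingLinear, Conv2d=torch.nn.Conv2d, Conv3d=torch.nn.Conv3d,
        GroupNorm=torch.nn.GroupNorm, LayerNorm=torch.nn.LayerNorm,
    )

    with use_patched_ops(ops):
        layer = torch.nn.Linear(8, 8)   # prints "patched Linear constructed"

    assert type(torch.nn.Linear(8, 8)) is torch.nn.Linear  # originals restored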