commit b731bb860c (parent 7cb6178d47)
lllyasviel
2024-01-16 02:33:39 -08:00
16 changed files with 250 additions and 216 deletions

View File

@@ -179,26 +179,11 @@ def manual_cast(target_dtype):
def autocast(disable=False):
if disable:
return contextlib.nullcontext()
if fp8 and device==cpu:
return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
if fp8 and dtype_inference == torch.float32:
return manual_cast(dtype)
if dtype == torch.float32 or dtype_inference == torch.float32:
return contextlib.nullcontext()
if has_xpu() or has_mps() or cuda_no_autocast():
return manual_cast(dtype)
return torch.autocast("cuda")
return contextlib.nullcontext()
def without_autocast(disable=False):
return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
return contextlib.nullcontext()
class NansException(Exception):
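
The hunk above collapses the fp8 and manual-cast branches: both autocast() and without_autocast() now simply hand back a null context, leaving dtype handling to the patched operations introduced elsewhere in this commit. A minimal, self-contained sketch of that pattern (illustrative only, not the repo's module):

import contextlib

def autocast(disable=False):
    # No device- or dtype-specific branches remain; callers get a no-op context.
    return contextlib.nullcontext()

def without_autocast(disable=False):
    return contextlib.nullcontext()

with autocast():
    result = 2 + 2   # any model call would run here unchanged
print(result)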

View File

@@ -12,6 +12,10 @@ def imports():
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules_forge.initialization import initialize_forge
initialize_forge()
startup_timer.record("initialize forge")
import torch # noqa: F401
startup_timer.record("import torch")
import pytorch_lightning # noqa: F401

View File

@@ -16,7 +16,7 @@ from skimage import exposure
from typing import Any
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules import devices, prompt_parser, masking, sd_samplers, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
@@ -627,44 +627,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans:
try:
devices.test_for_nans(sample, "vae")
except devices.NansException as e:
if shared.opts.auto_vae_precision_bfloat16:
autofix_dtype = torch.bfloat16
autofix_dtype_text = "bfloat16"
autofix_dtype_setting = "Automatically convert VAE to bfloat16"
autofix_dtype_comment = ""
elif shared.opts.auto_vae_precision:
autofix_dtype = torch.float32
autofix_dtype_text = "32-bit float"
autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
else:
raise e
if devices.dtype_vae == autofix_dtype:
raise e
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
)
devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae)
sample = decode_first_stage(model, batch[i:i + 1])[0]
if target_device is not None:
sample = sample.to(target_device)
samples.append(sample)
samples.append(sample.to(target_device))
return samples
@@ -940,8 +903,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
if p.scripts is not None:
ps = scripts.PostSampleArgs(samples_ddim)
@@ -960,9 +922,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
state.nextjob()
@@ -1255,7 +1214,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
image = np.array(self.firstpass_image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
image = torch.from_numpy(np.expand_dims(image, axis=0))
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image.to(shared.device, dtype=torch.float32)
if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1339,7 +1298,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
batch_images.append(image)
decoded_samples = torch.from_numpy(np.array(batch_images))
decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
decoded_samples = decoded_samples.to(shared.device, dtype=torch.float32)
if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
@@ -1444,7 +1403,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()
elif lowvram.is_enabled(shared.sd_model) and shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
elif shared.sd_model.sd_checkpoint_info == sd_models.select_checkpoint(): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
@@ -1631,7 +1590,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image.to(shared.device, dtype=torch.float32)
if opts.sd_vae_encode_method != 'Full':
self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method

View File

@@ -276,6 +276,12 @@ class DictWithShape(dict):
def shape(self):
return self["crossattn"].shape
def to(self, *args, **kwargs):
for k in self.keys():
if isinstance(self[k], torch.Tensor):
self[k] = self[k].to(*args, **kwargs)
return self
def reconstruct_cond_batch(c: list[list[ScheduledPromptConditioning]], current_step):
param = c[0][0].cond
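
The new DictWithShape.to() walks every value and forwards the call to tensors only, so a conditioning dict can be moved or cast like a single tensor. A self-contained sketch of the same idea:

import torch

class DictWithShape(dict):
    @property
    def shape(self):
        return self["crossattn"].shape

    def to(self, *args, **kwargs):
        # Forward .to() to every tensor value, leaving non-tensor entries alone.
        for k in self.keys():
            if isinstance(self[k], torch.Tensor):
                self[k] = self[k].to(*args, **kwargs)
        return self

cond = DictWithShape(crossattn=torch.zeros(2, 77, 768), vector=torch.zeros(2, 1280))
cond = cond.to(torch.float16)
print(cond.shape, cond["vector"].dtype)   # torch.Size([2, 77, 768]) torch.float16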

View File

@@ -3,6 +3,7 @@ from collections import namedtuple
import torch
import ldm_patched.modules.model_management as model_management
from modules import prompt_parser, devices, sd_hijack
from modules.shared import opts
@@ -210,6 +211,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
"""
if hasattr(self.wrapped, 'patcher'):
model_management.load_model_gpu(self.wrapped.patcher)
if opts.use_old_emphasis_implementation:
import modules.sd_hijack_clip_old
return modules.sd_hijack_clip_old.forward_old(self, texts)
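
The guard added above loads the wrapped text encoder's ModelPatcher onto the GPU only when a prompt is actually encoded. A toy sketch of that on-demand pattern, using a plain .to() in place of model_management.load_model_gpu (assumed from ldm_patched and not shown here):

import torch

class OnDemandEncoder(torch.nn.Module):
    def __init__(self, wrapped, device="cpu"):
        super().__init__()
        self.wrapped = wrapped
        self.compute_device = device

    def forward(self, x):
        # Analogue of: model_management.load_model_gpu(self.wrapped.patcher)
        self.wrapped.to(self.compute_device)
        return self.wrapped(x.to(self.compute_device))

enc = OnDemandEncoder(torch.nn.Linear(8, 8))
print(enc(torch.zeros(1, 8)).shape)   # torch.Size([1, 8])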

View File

@@ -13,10 +13,16 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, sd_hijack, patches
from modules.timer import Timer
import tomesd
import numpy as np
from modules_forge import forge_loader
import modules_forge.ops as forge_ops
from ldm_patched.modules.ops import manual_cast
from ldm_patched.modules import model_management as model_management
import ldm_patched.modules.model_patcher
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -366,10 +372,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_model_hash = checkpoint_info.calculate_shorthash()
timer.record("calculate hash")
if devices.fp8:
# prevent model to load state dict in fp8
model.half()
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title
@@ -379,13 +381,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in model.state_dict().keys()
if model.is_sdxl:
sd_models_xl.extend_sdxl(model)
if model.is_ssd:
sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
checkpoints_loaded[checkpoint_info] = state_dict.copy()
@@ -395,65 +394,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
timer.record("apply channels_last")
if shared.cmd_opts.no_half:
model.float()
model.alphas_cumprod_original = model.alphas_cumprod
devices.dtype_unet = torch.float32
timer.record("apply float()")
else:
vae = model.first_stage_model
depth_model = getattr(model, 'depth_model', None)
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
# with --upcast-sampling, don't convert the depth model weights to float16
if shared.cmd_opts.upcast_sampling and depth_model:
model.depth_model = None
alphas_cumprod = model.alphas_cumprod
model.alphas_cumprod = None
model.half()
model.alphas_cumprod = alphas_cumprod
model.alphas_cumprod_original = alphas_cumprod
model.first_stage_model = vae
if depth_model:
model.depth_model = depth_model
devices.dtype_unet = torch.float16
timer.record("apply half()")
for module in model.modules():
if hasattr(module, 'fp16_weight'):
del module.fp16_weight
if hasattr(module, 'fp16_bias'):
del module.fp16_bias
if check_fp8(model):
devices.fp8 = True
first_stage = model.first_stage_model
model.first_stage_model = None
for module in model.modules():
if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
if shared.opts.cache_fp16_weight:
module.fp16_weight = module.weight.data.clone().cpu().half()
if module.bias is not None:
module.fp16_bias = module.bias.data.clone().cpu().half()
module.to(torch.float8_e4m3fn)
model.first_stage_model = first_stage
timer.record("apply fp8")
else:
devices.fp8 = False
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae)
timer.record("apply dtype to VAE")
# clean up cache if limit is reached
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False)
@@ -614,34 +554,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
def send_model_to_cpu(m):
if m.lowvram:
lowvram.send_everything_to_cpu()
else:
m.to(devices.cpu)
devices.torch_gc()
def model_target_device(m):
if lowvram.is_needed(m):
return devices.cpu
else:
return devices.device
def send_model_to_device(m):
lowvram.apply(m)
if not m.lowvram:
m.to(shared.device)
def send_model_to_trash(m):
m.to(device="meta")
devices.torch_gc()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
@@ -649,7 +561,6 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer = Timer()
if model_data.sd_model:
send_model_to_trash(model_data.sd_model)
model_data.sd_model = None
devices.torch_gc()
@@ -670,12 +581,21 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("load config")
if hasattr(sd_config.model.params, 'network_config'):
sd_config.model.params.network_config.target = 'modules_forge.forge_loader.FakeObject'
if hasattr(sd_config.model.params, 'unet_config'):
sd_config.model.params.unet_config.target = 'modules_forge.forge_loader.FakeObject'
if hasattr(sd_config.model.params, 'first_stage_config'):
sd_config.model.params.first_stage_config.target = 'modules_forge.forge_loader.FakeObject'
print(f"Creating model from config: {checkpoint_config}")
sd_model = None
try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
with sd_disable_initialization.InitializeOnMeta():
with forge_ops.use_patched_ops(manual_cast):
sd_model = instantiate_from_config(sd_config.model)
except Exception as e:
@@ -684,28 +604,45 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
with sd_disable_initialization.InitializeOnMeta():
with forge_ops.use_patched_ops(manual_cast):
sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config
timer.record("create model")
if shared.cmd_opts.no_half:
weight_dtype_conversion = None
else:
weight_dtype_conversion = {
'first_stage_model': None,
'alphas_cumprod': None,
'': torch.float16,
}
state_dict_for_a1111 = {k: v for k, v in state_dict.items() if not k.startswith('model.diffusion_model.') and not k.startswith('first_stage_model.')}
state_dict_for_forge = {k: v for k, v in state_dict.items()}
del state_dict
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
unet_patcher, vae_patcher = forge_loader.load_unet_and_vae(state_dict_for_forge)
sd_model.first_stage_model = vae_patcher.first_stage_model
sd_model.model.diffusion_model = unet_patcher.model.diffusion_model
sd_model.unet_patcher = unet_patcher
sd_model.model.diffusion_model.patcher = unet_patcher
sd_model.vae_patcher = vae_patcher
sd_model.first_stage_model.patcher = vae_patcher
timer.record("create unet patcher")
del state_dict_for_forge
load_model_weights(sd_model, checkpoint_info, state_dict_for_a1111, timer)
del state_dict_for_a1111
timer.record("load weights from state dict")
send_model_to_device(sd_model)
timer.record("move model to device")
current_clip = sd_model.conditioner if hasattr(sd_model, 'conditioner') else sd_model.cond_stage_model
clip_load_device = model_management.text_encoder_device()
clip_offload_device = model_management.text_encoder_offload_device()
clip_dtype = model_management.text_encoder_dtype()
current_clip.to(clip_dtype)
clip_patcher = ldm_patched.modules.model_patcher.ModelPatcher(
current_clip,
load_device=clip_load_device,
offload_device=clip_offload_device
)
sd_model.clip_patcher = clip_patcher
current_clip.patcher = clip_patcher
timer.record("create clip patcher")
sd_hijack.model_hijack.hijack(sd_model)
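
Near the end of the hunk above the checkpoint is split by key prefix: the A1111-side model receives everything except the UNet and VAE weights, while an untouched copy goes to the Forge loader. A runnable sketch of that split with hypothetical keys:

import torch

state_dict = {
    "model.diffusion_model.input_blocks.0.0.weight": torch.zeros(1),
    "first_stage_model.encoder.conv_in.weight": torch.zeros(1),
    "cond_stage_model.transformer.text_model.embeddings.position_ids": torch.zeros(1),
}

state_dict_for_a1111 = {
    k: v for k, v in state_dict.items()
    if not k.startswith("model.diffusion_model.") and not k.startswith("first_stage_model.")
}
state_dict_for_forge = dict(state_dict)

print(list(state_dict_for_a1111))   # only the text-encoder key survives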
@@ -752,17 +689,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
model_data.loaded_sd_models.pop()
send_model_to_trash(loaded_model)
timer.record("send model to trash")
if shared.opts.sd_checkpoints_keep_in_cpu:
send_model_to_cpu(sd_model)
timer.record("send model to cpu")
if already_loaded is not None:
send_model_to_device(already_loaded)
timer.record("send model to device")
model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
@@ -816,7 +744,6 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
if sd_model is not None:
sd_unet.apply_unet("None")
send_model_to_cpu(sd_model)
sd_hijack.model_hijack.undo_hijack(sd_model)
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -827,7 +754,7 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
if sd_model is None or checkpoint_config != sd_model.used_config:
if sd_model is not None:
send_model_to_trash(sd_model)
sd_model = None
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
@@ -857,12 +784,6 @@ def reload_model_weights(sd_model=None, info=None, forced_reload=False):
return sd_model
def unload_model_weights(sd_model=None, info=None):
send_model_to_cpu(sd_model or shared.sd_model)
return sd_model
def apply_token_merging(sd_model, token_merging_ratio):
"""
Applies speed and memory optimizations from tomesd.

View File

@@ -8,8 +8,12 @@ import sgm.modules.diffusionmodules.discretizer
from modules import devices, shared, prompt_parser
from modules import torch_utils
import ldm_patched.modules.model_management as model_management
def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
model_management.load_model_gpu(self.clip_patcher)
for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0
@@ -18,7 +22,7 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
devices_args = dict(device=devices.device, dtype=devices.dtype)
devices_args = dict(device=self.clip_patcher.current_device, dtype=model_management.text_encoder_dtype())
sdxl_conds = {
"txt": batch,
@@ -35,11 +39,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
sd = self.model.state_dict()
diffusion_model_input = sd.get('diffusion_model.input_blocks.0.0.weight', None)
if diffusion_model_input is not None:
if diffusion_model_input.shape[1] == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
if self.model.diffusion_model.in_channels == 9:
x = torch.cat([x] + cond['c_concat'], dim=1)
return self.model(x, t, cond)
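
apply_model now decides whether to concatenate the inpainting conditioning by checking the UNet's declared in_channels rather than probing the state dict. A small sketch of that check with stand-in tensors:

import torch

in_channels = 9                      # would come from self.model.diffusion_model.in_channels
x = torch.zeros(1, 4, 64, 64)        # latent
c_concat = [torch.zeros(1, 5, 64, 64)]   # mask + masked-image conditioning

if in_channels == 9:
    x = torch.cat([x] + c_concat, dim=1)

print(x.shape)                       # torch.Size([1, 9, 64, 64])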

View File

@@ -173,6 +173,15 @@ class CFGDenoiser(torch.nn.Module):
uncond = pad_cond(uncond, num_repeats, empty)
self.padded_cond_uncond = True
unet_dtype = self.inner_model.inner_model.unet_patcher.model.model_config.unet_config['dtype']
x_input_dtype = x_in.dtype
x_in = x_in.to(unet_dtype)
sigma_in = sigma_in.to(unet_dtype)
image_cond_in = image_cond_in.to(unet_dtype)
tensor = tensor.to(unet_dtype)
uncond = uncond.to(unet_dtype)
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
if is_edit_model:
cond_in = catenate_conds([tensor, uncond, uncond])
@@ -211,6 +220,8 @@ class CFGDenoiser(torch.nn.Module):
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
x_out = x_out.to(x_input_dtype)
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
cfg_denoised_callback(denoised_params)
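
The casts added above move every sampler input to the UNet patcher's dtype before the batched cond/uncond call and restore the caller's dtype on the output. A self-contained sketch of that round-trip (the stand-in model just echoes its input):

import torch

def run_unet(unet, x_in, sigma_in, cond, unet_dtype=torch.float16):
    x_input_dtype = x_in.dtype
    x_out = unet(x_in.to(unet_dtype), sigma_in.to(unet_dtype), cond.to(unet_dtype))
    return x_out.to(x_input_dtype)    # restore the original dtype

unet = lambda x, sigma, cond: x       # stand-in for the real UNet call
y = run_unet(unet, torch.zeros(1, 4, 8, 8), torch.ones(1), torch.zeros(1, 77, 768))
print(y.dtype)                        # torch.float32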

View File

@@ -39,9 +39,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
from modules import lowvram
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
if approximation == 0:
approximation = 1
if approximation == 2:
@@ -54,8 +52,8 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
sample = model.unet_patcher.model.model_config.latent_format.process_out(sample)
x_sample = model.vae_patcher.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return x_sample
@@ -71,7 +69,6 @@ def single_sample_to_image(sample, approximation=None):
def decode_first_stage(model, x):
x = x.to(devices.dtype_vae)
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
return samples_to_images_tensor(x, approx_index, model)
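
The replacement decode path above gets images back from the VAE patcher in NHWC layout scaled to [0, 1]; movedim and the affine rescale bring them to the NCHW, [-1, 1] convention the rest of the pipeline expects. A sketch with a random stand-in for the decoded batch:

import torch

decoded = torch.rand(2, 64, 64, 3)                 # stand-in for vae_patcher.decode(...)
x_sample = decoded.movedim(-1, 1) * 2.0 - 1.0      # NHWC -> NCHW, [0, 1] -> [-1, 1]
print(x_sample.shape, float(x_sample.min()) >= -1.0, float(x_sample.max()) <= 1.0)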

View File

@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
import ldm_patched.modules.model_management
samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
@@ -139,6 +141,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return sigmas
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
inference_memory = 0
unet_patcher = self.model_wrap.inner_model.unet_patcher
ldm_patched.modules.model_management.load_models_gpu(
[unet_patcher],
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
sigmas = self.get_sigmas(p, steps)
@@ -193,6 +201,12 @@ class KDiffusionSampler(sd_samplers_common.Sampler):
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
inference_memory = 0
unet_patcher = self.model_wrap.inner_model.unet_patcher
ldm_patched.modules.model_management.load_models_gpu(
[unet_patcher],
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
steps = steps or p.steps
sigmas = self.get_sigmas(p, steps)
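
Both sampler entry points now pre-load the UNet patcher before sampling, asking load_models_gpu for enough memory to run cond and uncond in a single batched call; that is why the batch dimension is doubled in the shape handed to memory_required. The shape arithmetic on its own:

import torch

x = torch.zeros(2, 4, 64, 64)                        # latent batch for this job
doubled_shape = [x.shape[0] * 2] + list(x.shape[1:])
print(doubled_shape)                                  # [4, 4, 64, 64]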

View File

@@ -34,7 +34,7 @@ class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
def sigma_to_t(self, sigma, quantize=None):
log_sigma = sigma.log()
dists = log_sigma - self.log_sigmas[:, None]
dists = log_sigma - self.log_sigmas.to(sigma)[:, None]
return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
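
The one-line fix above relies on Tensor.to(other) matching both the device and the dtype of the other tensor, so the cached log_sigmas follow whatever device the incoming sigma lives on. A quick demonstration of that behaviour:

import torch

log_sigmas = torch.linspace(-5, 5, 10, dtype=torch.float64)   # e.g. kept on the CPU
sigma = torch.tensor([0.5], dtype=torch.float32)              # e.g. coming from the GPU
dists = sigma.log() - log_sigmas.to(sigma)[:, None]
print(dists.dtype, dists.shape)                               # torch.float32 torch.Size([10, 1])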

View File

@@ -7,6 +7,8 @@ from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback
from modules.shared import opts
import modules.shared as shared
import ldm_patched.modules.model_management
samplers_timesteps = [
('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
@@ -95,6 +97,12 @@ class CompVisSampler(sd_samplers_common.Sampler):
return timesteps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
inference_memory = 0
unet_patcher = self.model_wrap.inner_model.unet_patcher
ldm_patched.modules.model_management.load_models_gpu(
[unet_patcher],
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
timesteps = self.get_timesteps(p, steps)
@@ -139,6 +147,12 @@ class CompVisSampler(sd_samplers_common.Sampler):
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
inference_memory = 0
unet_patcher = self.model_wrap.inner_model.unet_patcher
ldm_patched.modules.model_management.load_models_gpu(
[unet_patcher],
unet_patcher.memory_required([x.shape[0] * 2] + list(x.shape[1:])) + inference_memory)
steps = steps or p.steps
timesteps = self.get_timesteps(p, steps)

View File

@@ -2,7 +2,7 @@ import os
import collections
from dataclasses import dataclass
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, sd_hijack, hashes
import glob
from copy import deepcopy
@@ -237,7 +237,6 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
model.first_stage_model.load_state_dict(vae_dict_1)
model.first_stage_model.to(devices.dtype_vae)
def clear_loaded_vae():
@@ -263,20 +262,12 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
if loaded_vae_file == vae_file:
return
if sd_model.lowvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
if not sd_model.lowvram:
sd_model.to(devices.device)
script_callbacks.model_loaded_callback(sd_model)
print("VAE weights loaded.")

View File

@@ -0,0 +1,42 @@
import torch
from ldm_patched.modules import model_management
from ldm_patched.modules import model_detection
from ldm_patched.modules.sd import VAE
import ldm_patched.modules.model_patcher
import ldm_patched.modules.utils
def load_unet_and_vae(sd):
parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
model_config.set_manual_cast(manual_cast_dtype)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of")
initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, "model.diffusion_model.", device=initial_load_device)
model.load_model_weights(sd, "model.diffusion_model.")
model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(model,
load_device=load_device,
offload_device=model_management.unet_offload_device(),
current_device=initial_load_device)
vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
vae_patcher = VAE(sd=vae_sd)
return model_patcher, vae_patcher
class FakeObject(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
return

View File

@@ -0,0 +1,66 @@
import argparse
def initialize_forge():
parser = argparse.ArgumentParser()
parser.add_argument("--no-vram", action="store_true")
parser.add_argument("--normal-vram", action="store_true")
parser.add_argument("--high-vram", action="store_true")
parser.add_argument("--always-vram", action="store_true")
vram_args = vars(parser.parse_known_args()[0])
parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--attention-split", action="store_true")
attn_group.add_argument("--attention-quad", action="store_true")
attn_group.add_argument("--attention-pytorch", action="store_true")
parser.add_argument("--disable-xformers", action="store_true")
fpte_group = parser.add_mutually_exclusive_group()
fpte_group.add_argument("--clip-in-fp8-e4m3fn", action="store_true")
fpte_group.add_argument("--clip-in-fp8-e5m2", action="store_true")
fpte_group.add_argument("--clip-in-fp16", action="store_true")
fpte_group.add_argument("--clip-in-fp32", action="store_true")
fp_group = parser.add_mutually_exclusive_group()
fp_group.add_argument("--all-in-fp32", action="store_true")
fp_group.add_argument("--all-in-fp16", action="store_true")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--unet-in-bf16", action="store_true")
fpunet_group.add_argument("--unet-in-fp16", action="store_true")
fpunet_group.add_argument("--unet-in-fp8-e4m3fn", action="store_true")
fpunet_group.add_argument("--unet-in-fp8-e5m2", action="store_true")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--vae-in-fp16", action="store_true")
fpvae_group.add_argument("--vae-in-fp32", action="store_true")
fpvae_group.add_argument("--vae-in-bf16", action="store_true")
other_args = vars(parser.parse_known_args()[0])
from ldm_patched.modules.args_parser import args
args.always_cpu = False
args.always_gpu = False
args.always_high_vram = False
args.always_low_vram = False
args.always_no_vram = False
args.always_offload_from_vram = True
args.async_cuda_allocation = False
args.disable_async_cuda_allocation = True
if vram_args['no_vram']:
args.always_cpu = True
if vram_args['always_vram']:
args.always_gpu = True
if vram_args['high_vram']:
args.always_offload_from_vram = False
for k, v in other_args.items():
setattr(args, k, v)
import ldm_patched.modules.model_management as model_management
import torch
device = model_management.get_torch_device()
torch.zeros((1, 1)).to(device, torch.float32)
model_management.soft_empty_cache()
return
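
The new initializer parses only Forge's own flags with parse_known_args, so anything it does not recognise (the regular webui command line) passes through untouched to the main parser. A minimal demonstration of that pattern, with example flags:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--high-vram", action="store_true")
known, remaining = parser.parse_known_args(["--high-vram", "--listen", "--port", "7860"])
print(vars(known))      # {'high_vram': True}
print(remaining)        # ['--listen', '--port', '7860']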

modules_forge/ops.py (new file, 19 lines)
View File

@@ -0,0 +1,19 @@
import torch
import contextlib
@contextlib.contextmanager
def use_patched_ops(operations):
op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
backups = {op_name: getattr(torch.nn, op_name) for op_name in op_names}
try:
for op_name in op_names:
setattr(torch.nn, op_name, getattr(operations, op_name))
yield
finally:
for op_name in op_names:
setattr(torch.nn, op_name, backups[op_name])
return
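
use_patched_ops temporarily swaps the torch.nn layer classes for the ones in the given operations namespace and always restores the originals, so any model constructed inside the block picks up the patched layers (manual_cast in the loader hunks above). A self-contained usage sketch with a hypothetical replacement namespace (the context manager is repeated so the snippet runs on its own):

import contextlib
import types
import torch

@contextlib.contextmanager
def use_patched_ops(operations):
    op_names = ['Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm']
    backups = {name: getattr(torch.nn, name) for name in op_names}
    try:
        for name in op_names:
            setattr(torch.nn, name, getattr(operations, name))
        yield
    finally:
        for name in op_names:
            setattr(torch.nn, name, backups[name])

class UninitLinear(torch.nn.Linear):
    def reset_parameters(self):
        pass    # hypothetical patched op: skip weight initialization

patched_ops = types.SimpleNamespace(
    Linear=UninitLinear, Conv2d=torch.nn.Conv2d, Conv3d=torch.nn.Conv3d,
    GroupNorm=torch.nn.GroupNorm, LayerNorm=torch.nn.LayerNorm)

with use_patched_ops(patched_ops):
    layer = torch.nn.Linear(4, 4)            # constructs UninitLinear while patched

print(type(layer).__name__)                   # UninitLinear
print(torch.nn.Linear is not UninitLinear)    # True: original class restored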