Merge branch 'main' into dev

This commit is contained in:
lllyasviel
2024-01-27 15:48:30 -08:00
committed by GitHub
66 changed files with 2746 additions and 2892 deletions

View File

@@ -5,6 +5,7 @@ import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules.shared import opts, state
from ldm_patched.modules import model_management
import k_diffusion.sampling
@@ -39,9 +40,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt):
approximation = approximation_indexes.get(opts.show_progress_type, 0)
from modules import lowvram
if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full:
if approximation == 0:
approximation = 1
if approximation == 2:
@@ -54,8 +53,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32
x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype))
x_sample = model.decode_first_stage(sample)
return x_sample
@@ -71,7 +69,6 @@ def single_sample_to_image(sample, approximation=None):
def decode_first_stage(model, x):
x = x.to(devices.dtype_vae)
approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0)
return samples_to_images_tensor(x, approx_index, model)
@@ -95,7 +92,6 @@ def images_tensor_to_samples(image, approximation=None, model=None):
else:
if model is None:
model = shared.sd_model
model.first_stage_model.to(devices.dtype_vae)
image = image.to(shared.device, dtype=devices.dtype_vae)
image = image * 2 - 1
@@ -155,7 +151,7 @@ def replace_torchsde_browinan():
replace_torchsde_browinan()
def apply_refiner(cfg_denoiser):
def apply_refiner(cfg_denoiser, x):
completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps
refiner_switch_at = cfg_denoiser.p.refiner_switch_at
refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info
@@ -184,10 +180,17 @@ def apply_refiner(cfg_denoiser):
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=refiner_checkpoint_info)
refiner = sd_models.model_data.get_sd_model()
devices.torch_gc()
cfg_denoiser.p.setup_conds()
cfg_denoiser.update_inner_model()
inference_memory = refiner.current_controlnet_required_memory
unet_patcher = refiner.forge_objects.unet
model_management.load_models_gpu(
[unet_patcher],
unet_patcher.memory_required([x.shape[0]] + list(x.shape[1:])) + inference_memory)
return True