Update webui.py

i

Update initialization.py

initialization

initialization

Update initialization.py

i

i

Update sd_samplers_common.py

Update sd_hijack.py

i

Update sd_models.py

Update sd_models.py

Update forge_loader.py

Update sd_models.py

i

Update sd_model.py

i

Update sd_models.py

Create sd_model.py

i

i

Update sd_models.py

i

Update sd_models.py

Update sd_models.py

i

i

Update sd_samplers_common.py

i

Update sd_models.py

Update sd_models.py

Update sd_samplers_common.py

Update sd_models.py

Update sd_models.py

Update sd_models.py

Update sd_models.py

Update sd_samplers_common.py

i

Update shared_options.py

Update prompt_parser.py

Update sd_hijack_unet.py

i

Update sd_models.py

Update sd_models.py

Update sd_models.py

Update devices.py

i

Update sd_vae.py

Update sd_models.py

Update processing.py

Update ui_settings.py

Update sd_models_xl.py

i

i

Update sd_samplers_kdiffusion.py

Update sd_samplers_timesteps.py

Update ui_settings.py

Update cmd_args.py

Update cmd_args.py

Update initialization.py

Update shared_options.py

Update initialization.py

Update shared_options.py

i

Update cmd_args.py

Update initialization.py

Update initialization.py

Update initialization.py

Update cmd_args.py

Update cmd_args.py

Update sd_hijack.py
This commit is contained in:
lllyasviel
2024-01-16 02:33:39 -08:00
parent 7cb6178d47
commit b731bb860c
16 changed files with 250 additions and 216 deletions

View File

@@ -0,0 +1,42 @@
import torch
from ldm_patched.modules import model_management
from ldm_patched.modules import model_detection
from ldm_patched.modules.sd import VAE
import ldm_patched.modules.model_patcher
import ldm_patched.modules.utils
def load_unet_and_vae(sd):
    """Split a full checkpoint state dict into a UNet patcher and a VAE.

    Parameters:
        sd: full checkpoint state dict containing "model.diffusion_model.*"
            keys for the UNet and "first_stage_model.*" keys for the VAE.

    Returns:
        (model_patcher, vae_patcher): a ModelPatcher wrapping the diffusion
        model and a VAE built from the checkpoint's first-stage weights.

    Raises:
        RuntimeError: if the model architecture cannot be detected from the
            state dict keys.
    """
    parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
    unet_dtype = model_management.unet_dtype(model_params=parameters)
    load_device = model_management.get_torch_device()
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)

    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
    # Bug fix: guard against failed detection BEFORE using model_config.
    # The original called set_manual_cast() first, which would raise an
    # AttributeError on None instead of the intended RuntimeError below.
    if model_config is None:
        raise RuntimeError("ERROR: Could not detect model type of")
    model_config.set_manual_cast(manual_cast_dtype)

    # Load the UNet weights on whichever device the manager picks for the
    # initial load (may differ from the steady-state load device).
    initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
    model = model_config.get_model(sd, "model.diffusion_model.", device=initial_load_device)
    model.load_model_weights(sd, "model.diffusion_model.")

    model_patcher = ldm_patched.modules.model_patcher.ModelPatcher(
        model,
        load_device=load_device,
        offload_device=model_management.unet_offload_device(),
        current_device=initial_load_device)

    # Strip the first-stage prefix so the VAE sees plain parameter names.
    vae_sd = ldm_patched.modules.utils.state_dict_prefix_replace(sd, {"first_stage_model.": ""}, filter_keys=True)
    vae_sd = model_config.process_vae_state_dict(vae_sd)
    vae_patcher = VAE(sd=vae_sd)

    return model_patcher, vae_patcher
class FakeObject(torch.nn.Module):
    """Placeholder nn.Module that accepts and discards any constructor arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__()

View File

@@ -0,0 +1,66 @@
import argparse
def initialize_forge():
    """Parse Forge command-line flags into ldm_patched's global args and warm
    up the torch device.

    Flags are read via parse_known_args, so anything this function does not
    recognize is left untouched for the main webui parser. After mapping the
    flags onto ``ldm_patched.modules.args_parser.args``, the compute device is
    touched once so the first real model load does not pay first-use cost.
    """
    # First pass: coarse VRAM-policy flags.
    # NOTE(review): --normal-vram is accepted but never consulted below —
    # presumably the normal policy is the default; confirm against callers.
    vram_parser = argparse.ArgumentParser()
    for flag in ("--no-vram", "--normal-vram", "--high-vram", "--always-vram"):
        vram_parser.add_argument(flag, action="store_true")
    vram_args = vars(vram_parser.parse_known_args()[0])

    # Second pass: attention backend and precision flags. Each group is
    # mutually exclusive, matching the backend's expectations.
    parser = argparse.ArgumentParser()

    attn_group = parser.add_mutually_exclusive_group()
    for flag in ("--attention-split", "--attention-quad", "--attention-pytorch"):
        attn_group.add_argument(flag, action="store_true")

    parser.add_argument("--disable-xformers", action="store_true")

    fpte_group = parser.add_mutually_exclusive_group()
    for flag in ("--clip-in-fp8-e4m3fn", "--clip-in-fp8-e5m2",
                 "--clip-in-fp16", "--clip-in-fp32"):
        fpte_group.add_argument(flag, action="store_true")

    fp_group = parser.add_mutually_exclusive_group()
    for flag in ("--all-in-fp32", "--all-in-fp16"):
        fp_group.add_argument(flag, action="store_true")

    fpunet_group = parser.add_mutually_exclusive_group()
    for flag in ("--unet-in-bf16", "--unet-in-fp16",
                 "--unet-in-fp8-e4m3fn", "--unet-in-fp8-e5m2"):
        fpunet_group.add_argument(flag, action="store_true")

    fpvae_group = parser.add_mutually_exclusive_group()
    for flag in ("--vae-in-fp16", "--vae-in-fp32", "--vae-in-bf16"):
        fpvae_group.add_argument(flag, action="store_true")

    other_args = vars(parser.parse_known_args()[0])

    from ldm_patched.modules.args_parser import args

    # Establish defaults before applying the parsed overrides.
    args.always_cpu = False
    args.always_gpu = False
    args.always_high_vram = False
    args.always_low_vram = False
    args.always_no_vram = False
    args.always_offload_from_vram = True
    args.async_cuda_allocation = False
    args.disable_async_cuda_allocation = True

    if vram_args['no_vram']:
        args.always_cpu = True
    if vram_args['always_vram']:
        args.always_gpu = True
    if vram_args['high_vram']:
        args.always_offload_from_vram = False

    # Attention/precision flags map one-to-one onto args attributes.
    for name, value in other_args.items():
        setattr(args, name, value)

    import ldm_patched.modules.model_management as model_management
    import torch

    # Touch the device with a tiny allocation, then release the cache.
    device = model_management.get_torch_device()
    torch.zeros((1, 1)).to(device, torch.float32)
    model_management.soft_empty_cache()

19
modules_forge/ops.py Normal file
View File

@@ -0,0 +1,19 @@
import torch
import contextlib
@contextlib.contextmanager
def use_patched_ops(operations):
    """Temporarily swap torch.nn layer classes for ``operations``' versions.

    Inside the ``with`` body, ``torch.nn.Linear``, ``Conv2d``, ``Conv3d``,
    ``GroupNorm`` and ``LayerNorm`` resolve to the attributes of the same
    name on ``operations``. The originals are always restored on exit, even
    when the body (or the patching itself) raises.
    """
    patched_names = ('Linear', 'Conv2d', 'Conv3d', 'GroupNorm', 'LayerNorm')
    saved = [(name, getattr(torch.nn, name)) for name in patched_names]

    try:
        for name in patched_names:
            setattr(torch.nn, name, getattr(operations, name))
        yield
    finally:
        # Restore unconditionally so a raising body cannot leak patches.
        for name, original in saved:
            setattr(torch.nn, name, original)