mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
move to new backend - part 2
This commit is contained in:
@@ -1,22 +1,21 @@
|
|||||||
import torch
|
import torch
|
||||||
import ldm_patched.modules.ops as ops
|
from backend import operations, memory_management
|
||||||
|
from backend.patcher.base import ModelPatcher
|
||||||
|
|
||||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
|
||||||
from ldm_patched.modules import model_management
|
|
||||||
from transformers import modeling_utils
|
from transformers import modeling_utils
|
||||||
|
|
||||||
|
|
||||||
class DiffusersModelPatcher:
|
class DiffusersModelPatcher:
|
||||||
def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
|
def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
|
||||||
load_device = model_management.get_torch_device()
|
load_device = memory_management.get_torch_device()
|
||||||
offload_device = torch.device("cpu")
|
offload_device = torch.device("cpu")
|
||||||
|
|
||||||
if not model_management.should_use_fp16(device=load_device):
|
if not memory_management.should_use_fp16(device=load_device):
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
with ops.use_patched_ops(ops.manual_cast):
|
with operations.using_forge_operations():
|
||||||
with modeling_utils.no_init_weights():
|
with modeling_utils.no_init_weights():
|
||||||
self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
|
self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
@@ -41,7 +40,7 @@ class DiffusersModelPatcher:
|
|||||||
def prepare_memory_before_sampling(self, batchsize, latent_width, latent_height):
|
def prepare_memory_before_sampling(self, batchsize, latent_width, latent_height):
|
||||||
area = 2 * batchsize * latent_width * latent_height
|
area = 2 * batchsize * latent_width * latent_height
|
||||||
inference_memory = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
inference_memory = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
|
||||||
model_management.load_models_gpu(
|
memory_management.load_models_gpu(
|
||||||
models=[self.patcher],
|
models=[self.patcher],
|
||||||
memory_required=inference_memory
|
memory_required=inference_memory
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,22 +1,23 @@
|
|||||||
from modules import sd_samplers_kdiffusion, sd_samplers_common
|
# from modules import sd_samplers_kdiffusion, sd_samplers_common
|
||||||
from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
|
# from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
|
||||||
|
#
|
||||||
|
#
|
||||||
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
# class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
|
||||||
def __init__(self, sd_model, sampler_name):
|
# def __init__(self, sd_model, sampler_name):
|
||||||
self.sampler_name = sampler_name
|
# self.sampler_name = sampler_name
|
||||||
self.unet = sd_model.forge_objects.unet
|
# self.unet = sd_model.forge_objects.unet
|
||||||
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
# sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
|
||||||
super().__init__(sampler_function, sd_model, None)
|
# super().__init__(sampler_function, sd_model, None)
|
||||||
|
#
|
||||||
|
#
|
||||||
def build_constructor(sampler_name):
|
# def build_constructor(sampler_name):
|
||||||
def constructor(m):
|
# def constructor(m):
|
||||||
return AlterSampler(m, sampler_name)
|
# return AlterSampler(m, sampler_name)
|
||||||
|
#
|
||||||
return constructor
|
# return constructor
|
||||||
|
#
|
||||||
|
#
|
||||||
|
|
||||||
samplers_data_alter = [
|
samplers_data_alter = [
|
||||||
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
|
# sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
|
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
|
||||||
from ldm_patched.modules import model_management
|
from backend import memory_management
|
||||||
from modules import sd_models
|
from modules import sd_models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
|
|
||||||
@@ -9,7 +9,7 @@ def move_clip_to_gpu():
|
|||||||
print('Error: CLIP called before SD is loaded!')
|
print('Error: CLIP called before SD is loaded!')
|
||||||
return
|
return
|
||||||
|
|
||||||
model_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
|
memory_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
from ldm_patched.modules import model_management
|
from backend import memory_management, utils
|
||||||
from ldm_patched.modules import model_detection
|
|
||||||
|
|
||||||
from ldm_patched.modules.sd import VAE, load_model_weights
|
|
||||||
from backend.patcher.clip import CLIP
|
from backend.patcher.clip import CLIP
|
||||||
from backend.patcher.vae import VAE
|
from backend.patcher.vae import VAE
|
||||||
import ldm_patched.modules.model_patcher
|
from backend.patcher.base import ModelPatcher
|
||||||
import ldm_patched.modules.utils
|
|
||||||
import ldm_patched.modules.clip_vision
|
|
||||||
import backend.nn.unet
|
import backend.nn.unet
|
||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -20,7 +15,6 @@ from modules.sd_models_xl import extend_sdxl
|
|||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from modules_forge import forge_clip
|
from modules_forge import forge_clip
|
||||||
from modules_forge.unet_patcher import UnetPatcher
|
from modules_forge.unet_patcher import UnetPatcher
|
||||||
from ldm_patched.modules.model_base import model_sampling, ModelType
|
|
||||||
from backend.loader import load_huggingface_components
|
from backend.loader import load_huggingface_components
|
||||||
from backend.modules.k_model import KModel
|
from backend.modules.k_model import KModel
|
||||||
|
|
||||||
@@ -85,13 +79,13 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
model_patcher = None
|
model_patcher = None
|
||||||
clip_target = None
|
clip_target = None
|
||||||
|
|
||||||
parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
|
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
|
||||||
unet_dtype = model_management.unet_dtype(model_params=parameters)
|
unet_dtype = memory_management.unet_dtype(model_params=parameters)
|
||||||
load_device = model_management.get_torch_device()
|
load_device = memory_management.get_torch_device()
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
|
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
|
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
|
||||||
|
|
||||||
initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
backend.nn.unet.unet_initial_device = initial_load_device
|
backend.nn.unet.unet_initial_device = initial_load_device
|
||||||
backend.nn.unet.unet_initial_dtype = unet_dtype
|
backend.nn.unet.unet_initial_dtype = unet_dtype
|
||||||
|
|
||||||
@@ -101,7 +95,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
|
k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
|
||||||
k_model.to(device=initial_load_device, dtype=unet_dtype)
|
k_model.to(device=initial_load_device, dtype=unet_dtype)
|
||||||
model_patcher = UnetPatcher(k_model, load_device=load_device,
|
model_patcher = UnetPatcher(k_model, load_device=load_device,
|
||||||
offload_device=model_management.unet_offload_device(),
|
offload_device=memory_management.unet_offload_device(),
|
||||||
current_device=initial_load_device)
|
current_device=initial_load_device)
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
|
|||||||
@@ -6,17 +6,17 @@ import random
|
|||||||
import string
|
import string
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
from ldm_patched.modules import model_management
|
from backend import memory_management
|
||||||
|
|
||||||
|
|
||||||
def prepare_free_memory(aggressive=False):
|
def prepare_free_memory(aggressive=False):
|
||||||
if aggressive:
|
if aggressive:
|
||||||
model_management.unload_all_models()
|
memory_management.unload_all_models()
|
||||||
print('Cleanup all memory.')
|
print('Cleanup all memory.')
|
||||||
return
|
return
|
||||||
|
|
||||||
model_management.free_memory(memory_required=model_management.minimum_inference_memory(),
|
memory_management.free_memory(memory_required=memory_management.minimum_inference_memory(),
|
||||||
device=model_management.get_torch_device())
|
device=memory_management.get_torch_device())
|
||||||
print('Cleanup minimal inference memory.')
|
print('Cleanup minimal inference memory.')
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -151,6 +151,6 @@ def resize_image_with_pad(img, resolution):
|
|||||||
|
|
||||||
|
|
||||||
def lazy_memory_management(model):
|
def lazy_memory_management(model):
|
||||||
required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
|
required_memory = memory_management.module_size(model) + memory_management.minimum_inference_memory()
|
||||||
model_management.free_memory(required_memory, device=model_management.get_torch_device())
|
memory_management.free_memory(required_memory, device=memory_management.get_torch_device())
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -35,15 +35,13 @@ def initialize_forge():
|
|||||||
print(f'In extreme cases, if you want to force previous lowvram/medvram behaviors, '
|
print(f'In extreme cases, if you want to force previous lowvram/medvram behaviors, '
|
||||||
f'please use --always-offload-from-vram')
|
f'please use --always-offload-from-vram')
|
||||||
|
|
||||||
from ldm_patched.modules import args_parser
|
from backend.args import args
|
||||||
|
|
||||||
args_parser.args, _ = args_parser.parser.parse_known_args()
|
if args.gpu_device_id is not None:
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
|
||||||
|
print("Set device to:", args.gpu_device_id)
|
||||||
|
|
||||||
if args_parser.args.gpu_device_id is not None:
|
if args.cuda_malloc:
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = str(args_parser.args.gpu_device_id)
|
|
||||||
print("Set device to:", args_parser.args.gpu_device_id)
|
|
||||||
|
|
||||||
if args_parser.args.cuda_malloc:
|
|
||||||
from modules_forge.cuda_malloc import try_cuda_malloc
|
from modules_forge.cuda_malloc import try_cuda_malloc
|
||||||
try_cuda_malloc()
|
try_cuda_malloc()
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import contextlib
|
import contextlib
|
||||||
from ldm_patched.modules import model_management
|
|
||||||
from ldm_patched.modules.ops import use_patched_ops
|
from backend import memory_management
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def automatic_memory_management():
|
def automatic_memory_management():
|
||||||
model_management.free_memory(
|
memory_management.free_memory(
|
||||||
memory_required=3 * 1024 * 1024 * 1024,
|
memory_required=3 * 1024 * 1024 * 1024,
|
||||||
device=model_management.get_torch_device()
|
device=memory_management.get_torch_device()
|
||||||
)
|
)
|
||||||
|
|
||||||
module_list = []
|
module_list = []
|
||||||
@@ -39,7 +39,7 @@ def automatic_memory_management():
|
|||||||
for module in module_list:
|
for module in module_list:
|
||||||
module.cpu()
|
module.cpu()
|
||||||
|
|
||||||
model_management.soft_empty_cache()
|
memory_management.soft_empty_cache()
|
||||||
end = time.perf_counter()
|
end = time.perf_counter()
|
||||||
|
|
||||||
print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
|
print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import ldm_patched.modules.utils
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
from backend import utils
|
||||||
from modules.paths_internal import models_path
|
from modules.paths_internal import models_path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ def add_supported_control_model(control_model):
|
|||||||
|
|
||||||
def try_load_supported_control_model(ckpt_path):
|
def try_load_supported_control_model(ckpt_path):
|
||||||
global supported_control_models
|
global supported_control_models
|
||||||
state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
|
state_dict = utils.load_torch_file(ckpt_path, safe_load=True)
|
||||||
for supported_type in supported_control_models:
|
for supported_type in supported_control_models:
|
||||||
state_dict_copy = {k: v for k, v in state_dict.items()}
|
state_dict_copy = {k: v for k, v in state_dict.items()}
|
||||||
model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)
|
model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ import cv2
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from modules_forge.shared import add_supported_preprocessor, preprocessor_dir
|
from modules_forge.shared import add_supported_preprocessor, preprocessor_dir
|
||||||
from ldm_patched.modules import model_management
|
from backend import memory_management
|
||||||
from ldm_patched.modules.model_patcher import ModelPatcher
|
from backend.patcher.base import ModelPatcher
|
||||||
|
from backend.patcher import clipvision
|
||||||
from modules_forge.forge_util import resize_image_with_pad
|
from modules_forge.forge_util import resize_image_with_pad
|
||||||
import ldm_patched.modules.clip_vision
|
|
||||||
from modules.modelloader import load_file_from_url
|
from modules.modelloader import load_file_from_url
|
||||||
from modules_forge.forge_util import numpy_to_pytorch
|
from modules_forge.forge_util import numpy_to_pytorch
|
||||||
|
|
||||||
@@ -37,12 +37,12 @@ class Preprocessor:
|
|||||||
|
|
||||||
def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
|
def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
|
||||||
if load_device is None:
|
if load_device is None:
|
||||||
load_device = model_management.get_torch_device()
|
load_device = memory_management.get_torch_device()
|
||||||
|
|
||||||
if offload_device is None:
|
if offload_device is None:
|
||||||
offload_device = torch.device('cpu')
|
offload_device = torch.device('cpu')
|
||||||
|
|
||||||
if not model_management.should_use_fp16(load_device):
|
if not memory_management.should_use_fp16(load_device):
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -53,7 +53,7 @@ class Preprocessor:
|
|||||||
return self.model_patcher
|
return self.model_patcher
|
||||||
|
|
||||||
def move_all_model_patchers_to_gpu(self):
|
def move_all_model_patchers_to_gpu(self):
|
||||||
model_management.load_models_gpu([self.model_patcher])
|
memory_management.load_models_gpu([self.model_patcher])
|
||||||
return
|
return
|
||||||
|
|
||||||
def send_tensor_to_model_device(self, x):
|
def send_tensor_to_model_device(self, x):
|
||||||
@@ -127,7 +127,7 @@ class PreprocessorClipVision(Preprocessor):
|
|||||||
if ckpt_path in PreprocessorClipVision.global_cache:
|
if ckpt_path in PreprocessorClipVision.global_cache:
|
||||||
self.clipvision = PreprocessorClipVision.global_cache[ckpt_path]
|
self.clipvision = PreprocessorClipVision.global_cache[ckpt_path]
|
||||||
else:
|
else:
|
||||||
self.clipvision = ldm_patched.modules.clip_vision.load(ckpt_path)
|
self.clipvision = clipvision.load(ckpt_path)
|
||||||
PreprocessorClipVision.global_cache[ckpt_path] = self.clipvision
|
PreprocessorClipVision.global_cache[ckpt_path] = self.clipvision
|
||||||
|
|
||||||
return self.clipvision
|
return self.clipvision
|
||||||
|
|||||||
Reference in New Issue
Block a user