From 4add428e250a9b91d1f2b20a7219d51af67c13db Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Sat, 3 Aug 2024 15:10:37 -0700
Subject: [PATCH] move to new backend - part 2

---
 modules_forge/diffusers_patcher.py      | 13 ++++-----
 modules_forge/forge_alter_samplers.py   | 39 +++++++++++++------------
 modules_forge/forge_clip.py             |  4 +--
 modules_forge/forge_loader.py           | 22 +++++---------
 modules_forge/forge_util.py             | 12 ++++----
 modules_forge/initialization.py         | 12 ++++----
 modules_forge/ops.py                    | 10 +++----
 modules_forge/shared.py                 |  4 +--
 modules_forge/supported_preprocessor.py | 14 ++++-----
 9 files changed, 61 insertions(+), 69 deletions(-)

diff --git a/modules_forge/diffusers_patcher.py b/modules_forge/diffusers_patcher.py
index e495330d..69f9a09b 100644
--- a/modules_forge/diffusers_patcher.py
+++ b/modules_forge/diffusers_patcher.py
@@ -1,22 +1,21 @@
 import torch
 
-import ldm_patched.modules.ops as ops
+from backend import operations, memory_management
+from backend.patcher.base import ModelPatcher
 
-from ldm_patched.modules.model_patcher import ModelPatcher
-from ldm_patched.modules import model_management
 from transformers import modeling_utils
 
 
 class DiffusersModelPatcher:
     def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
-        load_device = model_management.get_torch_device()
+        load_device = memory_management.get_torch_device()
         offload_device = torch.device("cpu")
 
-        if not model_management.should_use_fp16(device=load_device):
+        if not memory_management.should_use_fp16(device=load_device):
             dtype = torch.float32
 
         self.dtype = dtype
 
-        with ops.use_patched_ops(ops.manual_cast):
+        with operations.using_forge_operations():
             with modeling_utils.no_init_weights():
                 self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
 
@@ -41,7 +40,7 @@ class DiffusersModelPatcher:
     def prepare_memory_before_sampling(self, batchsize, latent_width, latent_height):
         area = 2 * batchsize * latent_width * latent_height
         inference_memory = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
-        model_management.load_models_gpu(
+        memory_management.load_models_gpu(
             models=[self.patcher],
             memory_required=inference_memory
         )
diff --git a/modules_forge/forge_alter_samplers.py b/modules_forge/forge_alter_samplers.py
index 8316d322..1f474c67 100644
--- a/modules_forge/forge_alter_samplers.py
+++ b/modules_forge/forge_alter_samplers.py
@@ -1,22 +1,23 @@
-from modules import sd_samplers_kdiffusion, sd_samplers_common
-from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
-
-
-class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
-    def __init__(self, sd_model, sampler_name):
-        self.sampler_name = sampler_name
-        self.unet = sd_model.forge_objects.unet
-        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
-        super().__init__(sampler_function, sd_model, None)
-
-
-def build_constructor(sampler_name):
-    def constructor(m):
-        return AlterSampler(m, sampler_name)
-
-    return constructor
-
+# from modules import sd_samplers_kdiffusion, sd_samplers_common
+# from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
+#
+#
+# class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
+#     def __init__(self, sd_model, sampler_name):
+#         self.sampler_name = sampler_name
+#         self.unet = sd_model.forge_objects.unet
+#         sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
+#         super().__init__(sampler_function, sd_model, None)
+#
+#
+# def build_constructor(sampler_name):
+#     def constructor(m):
+#         return AlterSampler(m, sampler_name)
+#
+#     return constructor
+#
+#
 
 samplers_data_alter = [
-    sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
+    # sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
 ]
diff --git a/modules_forge/forge_clip.py b/modules_forge/forge_clip.py
index 7c787c16..a9dccf50 100644
--- a/modules_forge/forge_clip.py
+++ b/modules_forge/forge_clip.py
@@ -1,5 +1,5 @@
 from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
-from ldm_patched.modules import model_management
+from backend import memory_management
 from modules import sd_models
 from modules.shared import opts
 
@@ -9,7 +9,7 @@ def move_clip_to_gpu():
         print('Error: CLIP called before SD is loaded!')
         return
 
-    model_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
+    memory_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
     return
 
 
diff --git a/modules_forge/forge_loader.py b/modules_forge/forge_loader.py
index cbe71e9f..fcc3345c 100644
--- a/modules_forge/forge_loader.py
+++ b/modules_forge/forge_loader.py
@@ -1,15 +1,10 @@
 import torch
 import contextlib
 
-from ldm_patched.modules import model_management
-from ldm_patched.modules import model_detection
-
-from ldm_patched.modules.sd import VAE, load_model_weights
+from backend import memory_management, utils
 from backend.patcher.clip import CLIP
 from backend.patcher.vae import VAE
-import ldm_patched.modules.model_patcher
-import ldm_patched.modules.utils
-import ldm_patched.modules.clip_vision
+from backend.patcher.base import ModelPatcher
 import backend.nn.unet
 
 from omegaconf import OmegaConf
@@ -20,7 +15,6 @@ from modules.sd_models_xl import extend_sdxl
 from ldm.util import instantiate_from_config
 from modules_forge import forge_clip
 from modules_forge.unet_patcher import UnetPatcher
-from ldm_patched.modules.model_base import model_sampling, ModelType
 from backend.loader import load_huggingface_components
 from backend.modules.k_model import KModel
 
@@ -85,13 +79,13 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
     model_patcher = None
     clip_target = None
 
-    parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
-    unet_dtype = model_management.unet_dtype(model_params=parameters)
-    load_device = model_management.get_torch_device()
-    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
+    parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
+    unet_dtype = memory_management.unet_dtype(model_params=parameters)
+    load_device = memory_management.get_torch_device()
+    manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
     manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
 
-    initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
+    initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
     backend.nn.unet.unet_initial_device = initial_load_device
     backend.nn.unet.unet_initial_dtype = unet_dtype
 
@@ -101,7 +95,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
         k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
         k_model.to(device=initial_load_device, dtype=unet_dtype)
         model_patcher = UnetPatcher(k_model, load_device=load_device,
-                                    offload_device=model_management.unet_offload_device(),
+                                    offload_device=memory_management.unet_offload_device(),
                                     current_device=initial_load_device)
 
     if output_vae:
diff --git a/modules_forge/forge_util.py b/modules_forge/forge_util.py
index df47b571..ff78328a 100644
--- a/modules_forge/forge_util.py
+++ b/modules_forge/forge_util.py
@@ -6,17 +6,17 @@ import random
 import string
 import cv2
 
-from ldm_patched.modules import model_management
+from backend import memory_management
 
 
 def prepare_free_memory(aggressive=False):
     if aggressive:
-        model_management.unload_all_models()
+        memory_management.unload_all_models()
         print('Cleanup all memory.')
         return
 
-    model_management.free_memory(memory_required=model_management.minimum_inference_memory(),
-                                 device=model_management.get_torch_device())
+    memory_management.free_memory(memory_required=memory_management.minimum_inference_memory(),
+                                  device=memory_management.get_torch_device())
     print('Cleanup minimal inference memory.')
     return
 
@@ -151,6 +151,6 @@ def resize_image_with_pad(img, resolution):
 
 
 def lazy_memory_management(model):
-    required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
-    model_management.free_memory(required_memory, device=model_management.get_torch_device())
+    required_memory = memory_management.module_size(model) + memory_management.minimum_inference_memory()
+    memory_management.free_memory(required_memory, device=memory_management.get_torch_device())
     return
diff --git a/modules_forge/initialization.py b/modules_forge/initialization.py
index 4fee214a..48e333e8 100644
--- a/modules_forge/initialization.py
+++ b/modules_forge/initialization.py
@@ -35,15 +35,13 @@ def initialize_forge():
         print(f'In extreme cases, if you want to force previous lowvram/medvram behaviors, '
              f'please use --always-offload-from-vram')
 
-    from ldm_patched.modules import args_parser
+    from backend.args import args
 
-    args_parser.args, _ = args_parser.parser.parse_known_args()
+    if args.gpu_device_id is not None:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
+        print("Set device to:", args.gpu_device_id)
 
-    if args_parser.args.gpu_device_id is not None:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args_parser.args.gpu_device_id)
-        print("Set device to:", args_parser.args.gpu_device_id)
-
-    if args_parser.args.cuda_malloc:
+    if args.cuda_malloc:
         from modules_forge.cuda_malloc import try_cuda_malloc
         try_cuda_malloc()
 
diff --git a/modules_forge/ops.py b/modules_forge/ops.py
index a0a0d171..aac2f259 100644
--- a/modules_forge/ops.py
+++ b/modules_forge/ops.py
@@ -1,15 +1,15 @@
 import time
 import torch
 import contextlib
-from ldm_patched.modules import model_management
-from ldm_patched.modules.ops import use_patched_ops
+
+from backend import memory_management
 
 
 @contextlib.contextmanager
 def automatic_memory_management():
-    model_management.free_memory(
+    memory_management.free_memory(
         memory_required=3 * 1024 * 1024 * 1024,
-        device=model_management.get_torch_device()
+        device=memory_management.get_torch_device()
     )
 
     module_list = []
@@ -39,7 +39,7 @@ def automatic_memory_management():
     for module in module_list:
         module.cpu()
 
-    model_management.soft_empty_cache()
+    memory_management.soft_empty_cache()
     end = time.perf_counter()
     print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')
 
diff --git a/modules_forge/shared.py b/modules_forge/shared.py
index a5aee828..76038866 100644
--- a/modules_forge/shared.py
+++ b/modules_forge/shared.py
@@ -1,7 +1,7 @@
 import os
-import ldm_patched.modules.utils
 import argparse
+from backend import utils
 
 from modules.paths_internal import models_path
 from pathlib import Path
 
@@ -57,7 +57,7 @@ def add_supported_control_model(control_model):
 
 def try_load_supported_control_model(ckpt_path):
     global supported_control_models
-    state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
+    state_dict = utils.load_torch_file(ckpt_path, safe_load=True)
     for supported_type in supported_control_models:
         state_dict_copy = {k: v for k, v in state_dict.items()}
         model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)
diff --git a/modules_forge/supported_preprocessor.py b/modules_forge/supported_preprocessor.py
index e9d9eb51..628b14b0 100644
--- a/modules_forge/supported_preprocessor.py
+++ b/modules_forge/supported_preprocessor.py
@@ -2,10 +2,10 @@ import cv2
 import torch
 
 from modules_forge.shared import add_supported_preprocessor, preprocessor_dir
-from ldm_patched.modules import model_management
-from ldm_patched.modules.model_patcher import ModelPatcher
+from backend import memory_management
+from backend.patcher.base import ModelPatcher
+from backend.patcher import clipvision
 from modules_forge.forge_util import resize_image_with_pad
-import ldm_patched.modules.clip_vision
 from modules.modelloader import load_file_from_url
 from modules_forge.forge_util import numpy_to_pytorch
 
@@ -37,12 +37,12 @@ class Preprocessor:
 
     def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
         if load_device is None:
-            load_device = model_management.get_torch_device()
+            load_device = memory_management.get_torch_device()
 
         if offload_device is None:
             offload_device = torch.device('cpu')
 
-        if not model_management.should_use_fp16(load_device):
+        if not memory_management.should_use_fp16(load_device):
             dtype = torch.float32
 
         model.eval()
@@ -53,7 +53,7 @@ class Preprocessor:
         return self.model_patcher
 
     def move_all_model_patchers_to_gpu(self):
-        model_management.load_models_gpu([self.model_patcher])
+        memory_management.load_models_gpu([self.model_patcher])
         return
 
     def send_tensor_to_model_device(self, x):
@@ -127,7 +127,7 @@ class PreprocessorClipVision(Preprocessor):
         if ckpt_path in PreprocessorClipVision.global_cache:
            self.clipvision = PreprocessorClipVision.global_cache[ckpt_path]
         else:
-            self.clipvision = ldm_patched.modules.clip_vision.load(ckpt_path)
+            self.clipvision = clipvision.load(ckpt_path)
             PreprocessorClipVision.global_cache[ckpt_path] = self.clipvision
 
         return self.clipvision
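
Taken together, the hunks above apply one mechanical rename: everything that previously came from ldm_patched.modules now comes from the backend package (model_management becomes memory_management, ldm_patched.modules.utils becomes backend.utils, ops.use_patched_ops(ops.manual_cast) becomes operations.using_forge_operations(), and ModelPatcher moves to backend.patcher.base). The sketch below is not part of the patch; it only strings together the backend calls that the hunks themselves exercise, while the function name load_and_patch_example, its arguments, and the build_model callable are hypothetical.

    import torch

    from backend import memory_management, operations, utils
    from backend.patcher.base import ModelPatcher


    def load_and_patch_example(build_model, ckpt_path):
        # Device and dtype decisions go through backend.memory_management,
        # as in diffusers_patcher.py and supported_preprocessor.py above.
        load_device = memory_management.get_torch_device()
        offload_device = torch.device('cpu')
        dtype = torch.float16 if memory_management.should_use_fp16(device=load_device) else torch.float32

        # Forge's patched operations replace ops.use_patched_ops(ops.manual_cast).
        with operations.using_forge_operations():
            model = build_model()
        model = model.to(dtype=dtype)
        model.eval()

        # State dicts load via backend.utils, as in shared.py above.
        state_dict = utils.load_torch_file(ckpt_path, safe_load=True)
        model.load_state_dict(state_dict, strict=False)

        # Wrap the module and schedule it onto the GPU through the new memory manager.
        patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
        memory_management.load_models_gpu([patcher])
        return patcher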