move to new backend - part 2

This commit is contained in:
layerdiffusion
2024-08-03 15:10:37 -07:00
parent 8a01b2c5db
commit 4add428e25
9 changed files with 61 additions and 69 deletions

View File

@@ -1,22 +1,21 @@
import torch import torch
import ldm_patched.modules.ops as ops from backend import operations, memory_management
from backend.patcher.base import ModelPatcher
from ldm_patched.modules.model_patcher import ModelPatcher
from ldm_patched.modules import model_management
from transformers import modeling_utils from transformers import modeling_utils
class DiffusersModelPatcher: class DiffusersModelPatcher:
def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs): def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
load_device = model_management.get_torch_device() load_device = memory_management.get_torch_device()
offload_device = torch.device("cpu") offload_device = torch.device("cpu")
if not model_management.should_use_fp16(device=load_device): if not memory_management.should_use_fp16(device=load_device):
dtype = torch.float32 dtype = torch.float32
self.dtype = dtype self.dtype = dtype
with ops.use_patched_ops(ops.manual_cast): with operations.using_forge_operations():
with modeling_utils.no_init_weights(): with modeling_utils.no_init_weights():
self.pipeline = pipeline_class.from_pretrained(*args, **kwargs) self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
@@ -41,7 +40,7 @@ class DiffusersModelPatcher:
def prepare_memory_before_sampling(self, batchsize, latent_width, latent_height): def prepare_memory_before_sampling(self, batchsize, latent_width, latent_height):
area = 2 * batchsize * latent_width * latent_height area = 2 * batchsize * latent_width * latent_height
inference_memory = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024) inference_memory = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
model_management.load_models_gpu( memory_management.load_models_gpu(
models=[self.patcher], models=[self.patcher],
memory_required=inference_memory memory_required=inference_memory
) )

View File

@@ -1,22 +1,23 @@
from modules import sd_samplers_kdiffusion, sd_samplers_common # from modules import sd_samplers_kdiffusion, sd_samplers_common
from ldm_patched.k_diffusion import sampling as k_diffusion_sampling # from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
#
#
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler): # class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
def __init__(self, sd_model, sampler_name): # def __init__(self, sd_model, sampler_name):
self.sampler_name = sampler_name # self.sampler_name = sampler_name
self.unet = sd_model.forge_objects.unet # self.unet = sd_model.forge_objects.unet
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name)) # sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None) # super().__init__(sampler_function, sd_model, None)
#
#
def build_constructor(sampler_name): # def build_constructor(sampler_name):
def constructor(m): # def constructor(m):
return AlterSampler(m, sampler_name) # return AlterSampler(m, sampler_name)
#
return constructor # return constructor
#
#
samplers_data_alter = [ samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}), # sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
] ]

View File

@@ -1,5 +1,5 @@
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
from ldm_patched.modules import model_management from backend import memory_management
from modules import sd_models from modules import sd_models
from modules.shared import opts from modules.shared import opts
@@ -9,7 +9,7 @@ def move_clip_to_gpu():
print('Error: CLIP called before SD is loaded!') print('Error: CLIP called before SD is loaded!')
return return
model_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher) memory_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
return return

View File

@@ -1,15 +1,10 @@
import torch import torch
import contextlib import contextlib
from ldm_patched.modules import model_management from backend import memory_management, utils
from ldm_patched.modules import model_detection
from ldm_patched.modules.sd import VAE, load_model_weights
from backend.patcher.clip import CLIP from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE from backend.patcher.vae import VAE
import ldm_patched.modules.model_patcher from backend.patcher.base import ModelPatcher
import ldm_patched.modules.utils
import ldm_patched.modules.clip_vision
import backend.nn.unet import backend.nn.unet
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -20,7 +15,6 @@ from modules.sd_models_xl import extend_sdxl
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules_forge import forge_clip from modules_forge import forge_clip
from modules_forge.unet_patcher import UnetPatcher from modules_forge.unet_patcher import UnetPatcher
from ldm_patched.modules.model_base import model_sampling, ModelType
from backend.loader import load_huggingface_components from backend.loader import load_huggingface_components
from backend.modules.k_model import KModel from backend.modules.k_model import KModel
@@ -85,13 +79,13 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
model_patcher = None model_patcher = None
clip_target = None clip_target = None
parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.") parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters) unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device() load_device = memory_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
backend.nn.unet.unet_initial_device = initial_load_device backend.nn.unet.unet_initial_device = initial_load_device
backend.nn.unet.unet_initial_dtype = unet_dtype backend.nn.unet.unet_initial_dtype = unet_dtype
@@ -101,7 +95,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype) k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
k_model.to(device=initial_load_device, dtype=unet_dtype) k_model.to(device=initial_load_device, dtype=unet_dtype)
model_patcher = UnetPatcher(k_model, load_device=load_device, model_patcher = UnetPatcher(k_model, load_device=load_device,
offload_device=model_management.unet_offload_device(), offload_device=memory_management.unet_offload_device(),
current_device=initial_load_device) current_device=initial_load_device)
if output_vae: if output_vae:

View File

@@ -6,17 +6,17 @@ import random
import string import string
import cv2 import cv2
from ldm_patched.modules import model_management from backend import memory_management
def prepare_free_memory(aggressive=False): def prepare_free_memory(aggressive=False):
if aggressive: if aggressive:
model_management.unload_all_models() memory_management.unload_all_models()
print('Cleanup all memory.') print('Cleanup all memory.')
return return
model_management.free_memory(memory_required=model_management.minimum_inference_memory(), memory_management.free_memory(memory_required=memory_management.minimum_inference_memory(),
device=model_management.get_torch_device()) device=memory_management.get_torch_device())
print('Cleanup minimal inference memory.') print('Cleanup minimal inference memory.')
return return
@@ -151,6 +151,6 @@ def resize_image_with_pad(img, resolution):
def lazy_memory_management(model): def lazy_memory_management(model):
required_memory = model_management.module_size(model) + model_management.minimum_inference_memory() required_memory = memory_management.module_size(model) + memory_management.minimum_inference_memory()
model_management.free_memory(required_memory, device=model_management.get_torch_device()) memory_management.free_memory(required_memory, device=memory_management.get_torch_device())
return return

View File

@@ -35,15 +35,13 @@ def initialize_forge():
print(f'In extreme cases, if you want to force previous lowvram/medvram behaviors, ' print(f'In extreme cases, if you want to force previous lowvram/medvram behaviors, '
f'please use --always-offload-from-vram') f'please use --always-offload-from-vram')
from ldm_patched.modules import args_parser from backend.args import args
args_parser.args, _ = args_parser.parser.parse_known_args() if args.gpu_device_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
print("Set device to:", args.gpu_device_id)
if args_parser.args.gpu_device_id is not None: if args.cuda_malloc:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args_parser.args.gpu_device_id)
print("Set device to:", args_parser.args.gpu_device_id)
if args_parser.args.cuda_malloc:
from modules_forge.cuda_malloc import try_cuda_malloc from modules_forge.cuda_malloc import try_cuda_malloc
try_cuda_malloc() try_cuda_malloc()

View File

@@ -1,15 +1,15 @@
import time import time
import torch import torch
import contextlib import contextlib
from ldm_patched.modules import model_management
from ldm_patched.modules.ops import use_patched_ops from backend import memory_management
@contextlib.contextmanager @contextlib.contextmanager
def automatic_memory_management(): def automatic_memory_management():
model_management.free_memory( memory_management.free_memory(
memory_required=3 * 1024 * 1024 * 1024, memory_required=3 * 1024 * 1024 * 1024,
device=model_management.get_torch_device() device=memory_management.get_torch_device()
) )
module_list = [] module_list = []
@@ -39,7 +39,7 @@ def automatic_memory_management():
for module in module_list: for module in module_list:
module.cpu() module.cpu()
model_management.soft_empty_cache() memory_management.soft_empty_cache()
end = time.perf_counter() end = time.perf_counter()
print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.') print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')

View File

@@ -1,7 +1,7 @@
import os import os
import ldm_patched.modules.utils
import argparse import argparse
from backend import utils
from modules.paths_internal import models_path from modules.paths_internal import models_path
from pathlib import Path from pathlib import Path
@@ -57,7 +57,7 @@ def add_supported_control_model(control_model):
def try_load_supported_control_model(ckpt_path): def try_load_supported_control_model(ckpt_path):
global supported_control_models global supported_control_models
state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True) state_dict = utils.load_torch_file(ckpt_path, safe_load=True)
for supported_type in supported_control_models: for supported_type in supported_control_models:
state_dict_copy = {k: v for k, v in state_dict.items()} state_dict_copy = {k: v for k, v in state_dict.items()}
model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path) model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)

View File

@@ -2,10 +2,10 @@ import cv2
import torch import torch
from modules_forge.shared import add_supported_preprocessor, preprocessor_dir from modules_forge.shared import add_supported_preprocessor, preprocessor_dir
from ldm_patched.modules import model_management from backend import memory_management
from ldm_patched.modules.model_patcher import ModelPatcher from backend.patcher.base import ModelPatcher
from backend.patcher import clipvision
from modules_forge.forge_util import resize_image_with_pad from modules_forge.forge_util import resize_image_with_pad
import ldm_patched.modules.clip_vision
from modules.modelloader import load_file_from_url from modules.modelloader import load_file_from_url
from modules_forge.forge_util import numpy_to_pytorch from modules_forge.forge_util import numpy_to_pytorch
@@ -37,12 +37,12 @@ class Preprocessor:
def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs): def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
if load_device is None: if load_device is None:
load_device = model_management.get_torch_device() load_device = memory_management.get_torch_device()
if offload_device is None: if offload_device is None:
offload_device = torch.device('cpu') offload_device = torch.device('cpu')
if not model_management.should_use_fp16(load_device): if not memory_management.should_use_fp16(load_device):
dtype = torch.float32 dtype = torch.float32
model.eval() model.eval()
@@ -53,7 +53,7 @@ class Preprocessor:
return self.model_patcher return self.model_patcher
def move_all_model_patchers_to_gpu(self): def move_all_model_patchers_to_gpu(self):
model_management.load_models_gpu([self.model_patcher]) memory_management.load_models_gpu([self.model_patcher])
return return
def send_tensor_to_model_device(self, x): def send_tensor_to_model_device(self, x):
@@ -127,7 +127,7 @@ class PreprocessorClipVision(Preprocessor):
if ckpt_path in PreprocessorClipVision.global_cache: if ckpt_path in PreprocessorClipVision.global_cache:
self.clipvision = PreprocessorClipVision.global_cache[ckpt_path] self.clipvision = PreprocessorClipVision.global_cache[ckpt_path]
else: else:
self.clipvision = ldm_patched.modules.clip_vision.load(ckpt_path) self.clipvision = clipvision.load(ckpt_path)
PreprocessorClipVision.global_cache[ckpt_path] = self.clipvision PreprocessorClipVision.global_cache[ckpt_path] = self.clipvision
return self.clipvision return self.clipvision