move to new backend - part 2

This commit is contained in:
layerdiffusion
2024-08-03 15:10:37 -07:00
parent 8a01b2c5db
commit 4add428e25
9 changed files with 61 additions and 69 deletions

View File

@@ -1,22 +1,21 @@
import torch
import ldm_patched.modules.ops as ops
from backend import operations, memory_management
from backend.patcher.base import ModelPatcher
from ldm_patched.modules.model_patcher import ModelPatcher
from ldm_patched.modules import model_management
from transformers import modeling_utils
class DiffusersModelPatcher:
def __init__(self, pipeline_class, dtype=torch.float16, *args, **kwargs):
load_device = model_management.get_torch_device()
load_device = memory_management.get_torch_device()
offload_device = torch.device("cpu")
if not model_management.should_use_fp16(device=load_device):
if not memory_management.should_use_fp16(device=load_device):
dtype = torch.float32
self.dtype = dtype
with ops.use_patched_ops(ops.manual_cast):
with operations.using_forge_operations():
with modeling_utils.no_init_weights():
self.pipeline = pipeline_class.from_pretrained(*args, **kwargs)
@@ -41,7 +40,7 @@ class DiffusersModelPatcher:
def prepare_memory_before_sampling(self, batchsize, latent_width, latent_height):
area = 2 * batchsize * latent_width * latent_height
inference_memory = (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
model_management.load_models_gpu(
memory_management.load_models_gpu(
models=[self.patcher],
memory_required=inference_memory
)

View File

@@ -1,22 +1,23 @@
from modules import sd_samplers_kdiffusion, sd_samplers_common
from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
def __init__(self, sd_model, sampler_name):
self.sampler_name = sampler_name
self.unet = sd_model.forge_objects.unet
sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
super().__init__(sampler_function, sd_model, None)
def build_constructor(sampler_name):
def constructor(m):
return AlterSampler(m, sampler_name)
return constructor
# from modules import sd_samplers_kdiffusion, sd_samplers_common
# from ldm_patched.k_diffusion import sampling as k_diffusion_sampling
#
#
# class AlterSampler(sd_samplers_kdiffusion.KDiffusionSampler):
# def __init__(self, sd_model, sampler_name):
# self.sampler_name = sampler_name
# self.unet = sd_model.forge_objects.unet
# sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
# super().__init__(sampler_function, sd_model, None)
#
#
# def build_constructor(sampler_name):
# def constructor(m):
# return AlterSampler(m, sampler_name)
#
# return constructor
#
#
samplers_data_alter = [
sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
# sd_samplers_common.SamplerData('DDPM', build_constructor(sampler_name='ddpm'), ['ddpm'], {}),
]

View File

@@ -1,5 +1,5 @@
from modules.sd_hijack_clip import FrozenCLIPEmbedderWithCustomWords
from ldm_patched.modules import model_management
from backend import memory_management
from modules import sd_models
from modules.shared import opts
@@ -9,7 +9,7 @@ def move_clip_to_gpu():
print('Error: CLIP called before SD is loaded!')
return
model_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
memory_management.load_model_gpu(sd_models.model_data.sd_model.forge_objects.clip.patcher)
return

View File

@@ -1,15 +1,10 @@
import torch
import contextlib
from ldm_patched.modules import model_management
from ldm_patched.modules import model_detection
from ldm_patched.modules.sd import VAE, load_model_weights
from backend import memory_management, utils
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
import ldm_patched.modules.model_patcher
import ldm_patched.modules.utils
import ldm_patched.modules.clip_vision
from backend.patcher.base import ModelPatcher
import backend.nn.unet
from omegaconf import OmegaConf
@@ -20,7 +15,6 @@ from modules.sd_models_xl import extend_sdxl
from ldm.util import instantiate_from_config
from modules_forge import forge_clip
from modules_forge.unet_patcher import UnetPatcher
from ldm_patched.modules.model_base import model_sampling, ModelType
from backend.loader import load_huggingface_components
from backend.modules.k_model import KModel
@@ -85,13 +79,13 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
model_patcher = None
clip_target = None
parameters = ldm_patched.modules.utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
parameters = utils.calculate_parameters(sd, "model.diffusion_model.")
unet_dtype = memory_management.unet_dtype(model_params=parameters)
load_device = memory_management.get_torch_device()
manual_cast_dtype = memory_management.unet_manual_cast(unet_dtype, load_device)
manual_cast_dtype = unet_dtype if manual_cast_dtype is None else manual_cast_dtype
initial_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
initial_load_device = memory_management.unet_inital_load_device(parameters, unet_dtype)
backend.nn.unet.unet_initial_device = initial_load_device
backend.nn.unet.unet_initial_dtype = unet_dtype
@@ -101,7 +95,7 @@ def load_checkpoint_guess_config(sd, output_vae=True, output_clip=True, output_c
k_model = KModel(huggingface_components, storage_dtype=unet_dtype, computation_dtype=manual_cast_dtype)
k_model.to(device=initial_load_device, dtype=unet_dtype)
model_patcher = UnetPatcher(k_model, load_device=load_device,
offload_device=model_management.unet_offload_device(),
offload_device=memory_management.unet_offload_device(),
current_device=initial_load_device)
if output_vae:

View File

@@ -6,17 +6,17 @@ import random
import string
import cv2
from ldm_patched.modules import model_management
from backend import memory_management
def prepare_free_memory(aggressive=False):
if aggressive:
model_management.unload_all_models()
memory_management.unload_all_models()
print('Cleanup all memory.')
return
model_management.free_memory(memory_required=model_management.minimum_inference_memory(),
device=model_management.get_torch_device())
memory_management.free_memory(memory_required=memory_management.minimum_inference_memory(),
device=memory_management.get_torch_device())
print('Cleanup minimal inference memory.')
return
@@ -151,6 +151,6 @@ def resize_image_with_pad(img, resolution):
def lazy_memory_management(model):
required_memory = model_management.module_size(model) + model_management.minimum_inference_memory()
model_management.free_memory(required_memory, device=model_management.get_torch_device())
required_memory = memory_management.module_size(model) + memory_management.minimum_inference_memory()
memory_management.free_memory(required_memory, device=memory_management.get_torch_device())
return

View File

@@ -35,15 +35,13 @@ def initialize_forge():
print(f'In extreme cases, if you want to force previous lowvram/medvram behaviors, '
f'please use --always-offload-from-vram')
from ldm_patched.modules import args_parser
from backend.args import args
args_parser.args, _ = args_parser.parser.parse_known_args()
if args.gpu_device_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
print("Set device to:", args.gpu_device_id)
if args_parser.args.gpu_device_id is not None:
os.environ['CUDA_VISIBLE_DEVICES'] = str(args_parser.args.gpu_device_id)
print("Set device to:", args_parser.args.gpu_device_id)
if args_parser.args.cuda_malloc:
if args.cuda_malloc:
from modules_forge.cuda_malloc import try_cuda_malloc
try_cuda_malloc()

View File

@@ -1,15 +1,15 @@
import time
import torch
import contextlib
from ldm_patched.modules import model_management
from ldm_patched.modules.ops import use_patched_ops
from backend import memory_management
@contextlib.contextmanager
def automatic_memory_management():
model_management.free_memory(
memory_management.free_memory(
memory_required=3 * 1024 * 1024 * 1024,
device=model_management.get_torch_device()
device=memory_management.get_torch_device()
)
module_list = []
@@ -39,7 +39,7 @@ def automatic_memory_management():
for module in module_list:
module.cpu()
model_management.soft_empty_cache()
memory_management.soft_empty_cache()
end = time.perf_counter()
print(f'Automatic Memory Management: {len(module_list)} Modules in {(end - start):.2f} seconds.')

View File

@@ -1,7 +1,7 @@
import os
import ldm_patched.modules.utils
import argparse
from backend import utils
from modules.paths_internal import models_path
from pathlib import Path
@@ -57,7 +57,7 @@ def add_supported_control_model(control_model):
def try_load_supported_control_model(ckpt_path):
global supported_control_models
state_dict = ldm_patched.modules.utils.load_torch_file(ckpt_path, safe_load=True)
state_dict = utils.load_torch_file(ckpt_path, safe_load=True)
for supported_type in supported_control_models:
state_dict_copy = {k: v for k, v in state_dict.items()}
model = supported_type.try_build_from_state_dict(state_dict_copy, ckpt_path)

View File

@@ -2,10 +2,10 @@ import cv2
import torch
from modules_forge.shared import add_supported_preprocessor, preprocessor_dir
from ldm_patched.modules import model_management
from ldm_patched.modules.model_patcher import ModelPatcher
from backend import memory_management
from backend.patcher.base import ModelPatcher
from backend.patcher import clipvision
from modules_forge.forge_util import resize_image_with_pad
import ldm_patched.modules.clip_vision
from modules.modelloader import load_file_from_url
from modules_forge.forge_util import numpy_to_pytorch
@@ -37,12 +37,12 @@ class Preprocessor:
def setup_model_patcher(self, model, load_device=None, offload_device=None, dtype=torch.float32, **kwargs):
if load_device is None:
load_device = model_management.get_torch_device()
load_device = memory_management.get_torch_device()
if offload_device is None:
offload_device = torch.device('cpu')
if not model_management.should_use_fp16(load_device):
if not memory_management.should_use_fp16(load_device):
dtype = torch.float32
model.eval()
@@ -53,7 +53,7 @@ class Preprocessor:
return self.model_patcher
def move_all_model_patchers_to_gpu(self):
model_management.load_models_gpu([self.model_patcher])
memory_management.load_models_gpu([self.model_patcher])
return
def send_tensor_to_model_device(self, x):
@@ -127,7 +127,7 @@ class PreprocessorClipVision(Preprocessor):
if ckpt_path in PreprocessorClipVision.global_cache:
self.clipvision = PreprocessorClipVision.global_cache[ckpt_path]
else:
self.clipvision = ldm_patched.modules.clip_vision.load(ckpt_path)
self.clipvision = clipvision.load(ckpt_path)
PreprocessorClipVision.global_cache[ckpt_path] = self.clipvision
return self.clipvision