mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-02-21 07:13:56 +00:00
rework memory management for extras
now face post-processing uses gpu close #312
This commit is contained in:
@@ -2,8 +2,9 @@ import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerDAT(Upscaler):
|
||||
|
||||
@@ -50,10 +50,10 @@ def enable_tf32():
|
||||
cpu: torch.device = torch.device("cpu")
|
||||
fp8: bool = False
|
||||
device: torch.device = model_management.get_torch_device()
|
||||
device_interrogate: torch.device = cpu # not used
|
||||
device_gfpgan: torch.device = cpu
|
||||
device_esrgan: torch.device = model_management.get_torch_device() # will be managed in special way
|
||||
device_codeformer: torch.device = cpu
|
||||
device_interrogate: torch.device = model_management.text_encoder_device() # for backward compatibility, not used now
|
||||
device_gfpgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
device_esrgan: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
device_codeformer: torch.device = model_management.get_torch_device() # will be managed by memory management system
|
||||
dtype: torch.dtype = model_management.unet_dtype()
|
||||
dtype_vae: torch.dtype = model_management.vae_dtype()
|
||||
dtype_unet: torch.dtype = model_management.unet_dtype()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from modules import modelloader, devices, errors
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerESRGAN(Upscaler):
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from modules import devices, errors, face_restoration, shared
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
@@ -153,6 +154,7 @@ class CommonFaceRestoration(face_restoration.FaceRestoration):
|
||||
return np_image
|
||||
|
||||
try:
|
||||
prepare_free_memory()
|
||||
self.send_model_to(self.get_device())
|
||||
return restore_with_face_helper(np_image, self.face_helper, restore_face)
|
||||
finally:
|
||||
|
||||
@@ -3,8 +3,9 @@ import sys
|
||||
|
||||
from modules import modelloader, devices
|
||||
from modules.shared import opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerHAT(Upscaler):
|
||||
|
||||
@@ -2,8 +2,9 @@ import os
|
||||
|
||||
from modules import modelloader, errors
|
||||
from modules.shared import cmd_opts, opts
|
||||
from modules.upscaler import Upscaler, UpscalerData, prepare_free_memory
|
||||
from modules.upscaler import Upscaler, UpscalerData
|
||||
from modules.upscaler_utils import upscale_with_model
|
||||
from modules_forge.forge_util import prepare_free_memory
|
||||
|
||||
|
||||
class UpscalerRealESRGAN(Upscaler):
|
||||
|
||||
@@ -6,17 +6,6 @@ from PIL import Image
|
||||
|
||||
import modules.shared
|
||||
from modules import modelloader, shared
|
||||
from ldm_patched.modules import model_management
|
||||
|
||||
|
||||
def prepare_free_memory(aggressive=False):
|
||||
if aggressive:
|
||||
model_management.unload_all_models()
|
||||
print('Upscale script freed all memory.')
|
||||
return
|
||||
|
||||
model_management.free_memory(memory_required=1024*1024*3, device=model_management.get_torch_device())
|
||||
print('Upscale script freed memory successfully.')
|
||||
|
||||
|
||||
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
|
||||
|
||||
Reference in New Issue
Block a user