diff --git a/modules_forge/shared.py b/modules_forge/shared.py index 57c4b9d6..cdf8802f 100644 --- a/modules_forge/shared.py +++ b/modules_forge/shared.py @@ -1,7 +1,10 @@ import cv2 import os +import torch from modules.paths import models_path +from ldm_patched.modules import model_management +from ldm_patched.modules.model_patcher import ModelPatcher controlnet_dir = os.path.join(models_path, 'ControlNet') @@ -33,9 +36,17 @@ class PreprocessorBase: self.slider_1 = PreprocessorParameter() self.slider_2 = PreprocessorParameter() self.slider_3 = PreprocessorParameter() + self.model_patcher = None - def __call__(self, input_image, slider_1=None, slider_2=None, slider_3=None, **kwargs): - return input_image + def setup_model_patcher(self, model, load_device=None, offload_device=None, **kwargs): + if load_device is None: + load_device = model_management.get_torch_device() + + if offload_device is None: + offload_device = torch.device('cpu') + + self.model_patcher = ModelPatcher(model=model, load_device=load_device, offload_device=offload_device, **kwargs) + return class PreprocessorNone(PreprocessorBase):