diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py index a7d72ef2..8ec9509b 100644 --- a/backend/diffusion_engine/base.py +++ b/backend/diffusion_engine/base.py @@ -29,19 +29,8 @@ class ForgeDiffusionEngine: self.forge_objects_after_applying_lora = None self.current_lora_hash = str([]) - self.tiling_enabled = False - self.first_stage_model = None # set this so that you can change VAE in UI - self.use_distilled_cfg_scale = False - - # WebUI Dirty Legacy - self.is_sd1 = False - self.is_sd2 = False - self.is_sdxl = False - self.is_sd3 = False - - def is_webui_legacy_model(self): - return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3 + self.fix_for_webui_backward_compatibility() def set_clip_skip(self, clip_skip): pass @@ -59,4 +48,18 @@ class ForgeDiffusionEngine: pass def get_prompt_lengths_on_ui(self, prompt): - pass + return 0, 75 + + def is_webui_legacy_model(self): + return self.is_sd1 or self.is_sd2 or self.is_sdxl or self.is_sd3 + + def fix_for_webui_backward_compatibility(self): + self.tiling_enabled = False + self.first_stage_model = None + self.cond_stage_model = None + self.use_distilled_cfg_scale = False + self.is_sd1 = False + self.is_sd2 = False + self.is_sdxl = False + self.is_sd3 = False + return diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 192810f4..654f576b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,41 +1,32 @@ -# class StableDiffusionModelHijack: -# fixes = None -# layers = None -# circular_enabled = False -# clip = None -# optimization_method = None -# -# def __init__(self): -# self.extra_generation_params = {} -# self.comments = [] -# -# def apply_optimizations(self, option=None): -# pass -# -# def convert_sdxl_to_ssd(self, m): -# pass -# -# def hijack(self, m): -# pass -# -# def undo_hijack(self, m): -# pass -# -# def apply_circular(self, enable): -# pass -# -# def clear_comments(self): -# self.comments = [] -# self.extra_generation_params = {} -# -# def get_prompt_lengths(self, text, cond_stage_model): -# pass -# -# def redo_hijack(self, m): -# pass -# -# -# model_hijack = StableDiffusionModelHijack() +class StableDiffusionModelHijack: + + def apply_optimizations(self, option=None): + pass + + def convert_sdxl_to_ssd(self, m): + pass + + def hijack(self, m): + pass + + def undo_hijack(self, m): + pass + + def apply_circular(self, enable): + pass + + def clear_comments(self): + pass + + def get_prompt_lengths(self, text, cond_stage_model): + from modules import shared + return shared.sd_model.get_prompt_lengths_on_ui(text) + + def redo_hijack(self, m): + pass + + +model_hijack = StableDiffusionModelHijack() # import torch # from torch.nn.functional import silu