From ae1d995d0d930babeff561476c06bb10b2d172bf Mon Sep 17 00:00:00 2001 From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com> Date: Mon, 5 Aug 2024 21:05:25 -0700 Subject: [PATCH] Finally removed model_hijack finally --- modules/api/api.py | 2 +- modules/initialize_util.py | 4 +- modules/processing.py | 14 +++--- modules/sd_hijack.py | 76 ++++++++++++++++----------------- modules/sd_models.py | 2 + modules/sd_vae.py | 4 +- modules/textual_inversion/ui.py | 76 ++++++++++++++++----------------- modules/ui.py | 1 - 8 files changed, 90 insertions(+), 89 deletions(-) diff --git a/modules/api/api.py b/modules/api/api.py index 3316a1f3..33ce1d6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers +from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 693b083c..0458e222 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,7 +177,7 @@ def configure_sigint_handler(): def configure_opts_onchange(): - from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack + from modules import shared, sd_models, sd_vae, ui_tempdir from modules.call_queue import wrap_queued_call from modules_forge import main_thread @@ -186,7 +186,7 @@ def configure_opts_onchange(): shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) - shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + # shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False) startup_timer.record("opts onchange") diff --git a/modules/processing.py b/modules/processing.py index b36a8c85..fea50715 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -18,7 +18,6 @@ from typing import Any import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling from modules.rng import slerp # noqa: F401 -from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -451,7 +450,7 @@ class StableDiffusionProcessing: for cache in caches: if cache[0] is not None and cached_params == cache[0]: if len(cache) > 2: - modules.sd_hijack.model_hijack.extra_generation_params.update(cache[2]) + shared.sd_model.extra_generation_params.update(cache[2]) return cache[1] cache = caches[0] @@ -465,7 +464,7 @@ class StableDiffusionProcessing: last_extra_generation_params = backend.text_processing.classic_engine.last_extra_generation_params.copy() - modules.sd_hijack.model_hijack.extra_generation_params.update(last_extra_generation_params) + shared.sd_model.extra_generation_params.update(last_extra_generation_params) if len(cache) > 2: cache[2] = last_extra_generation_params @@ -835,7 +834,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.sd_vae_hash = sd_vae.get_loaded_vae_hash() apply_circular_forge(p.sd_model, p.tiling) - modules.sd_hijack.model_hijack.clear_comments() + p.sd_model.comments = [] + p.sd_model.extra_generation_params = {} p.fill_fields_from_opts() p.setup_prompts() @@ -911,7 +911,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - p.extra_generation_params.update(model_hijack.extra_generation_params) + p.extra_generation_params.update(p.sd_model.extra_generation_params) # params.txt should be saved after scripts.process_batch, since the # infotext could be modified by that callback @@ -922,7 +922,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: processed = Processed(p, []) file.write(processed.infotext(p, 0)) - for comment in model_hijack.comments: + for comment in p.sd_model.comments: p.comment(comment) if p.n_iter > 1: @@ -1628,7 +1628,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.mask_for_overlay = None self.inpaint_full_res = False massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.' - model_hijack.comments.append(massage) + self.sd_model.comments.append(massage) logging.info(massage) else: image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 23d5ed6a..192810f4 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -1,41 +1,41 @@ -class StableDiffusionModelHijack: - fixes = None - layers = None - circular_enabled = False - clip = None - optimization_method = None - - def __init__(self): - self.extra_generation_params = {} - self.comments = [] - - def apply_optimizations(self, option=None): - pass - - def convert_sdxl_to_ssd(self, m): - pass - - def hijack(self, m): - pass - - def undo_hijack(self, m): - pass - - def apply_circular(self, enable): - pass - - def clear_comments(self): - self.comments = [] - self.extra_generation_params = {} - - def get_prompt_lengths(self, text, cond_stage_model): - pass - - def redo_hijack(self, m): - pass - - -model_hijack = StableDiffusionModelHijack() +# class StableDiffusionModelHijack: +# fixes = None +# layers = None +# circular_enabled = False +# clip = None +# optimization_method = None +# +# def __init__(self): +# self.extra_generation_params = {} +# self.comments = [] +# +# def apply_optimizations(self, option=None): +# pass +# +# def convert_sdxl_to_ssd(self, m): +# pass +# +# def hijack(self, m): +# pass +# +# def undo_hijack(self, m): +# pass +# +# def apply_circular(self, enable): +# pass +# +# def clear_comments(self): +# self.comments = [] +# self.extra_generation_params = {} +# +# def get_prompt_lengths(self, text, cond_stage_model): +# pass +# +# def redo_hijack(self, m): +# pass +# +# +# model_hijack = StableDiffusionModelHijack() # import torch # from torch.nn.functional import silu diff --git a/modules/sd_models.py b/modules/sd_models.py index 51ccab8f..1de7a2ba 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -562,6 +562,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = forge_loader(state_dict) timer.record("forge model load") + sd_model.extra_generation_params = {} + sd_model.comments = [] sd_model.sd_checkpoint_info = checkpoint_info sd_model.filename = checkpoint_info.filename sd_model.sd_model_hash = checkpoint_info.calculate_shorthash() diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 62fd6524..6f17c5b6 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -262,11 +262,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): if loaded_vae_file == vae_file: return - sd_hijack.model_hijack.undo_hijack(sd_model) + # sd_hijack.model_hijack.undo_hijack(sd_model) load_vae(sd_model, vae_file, vae_source) - sd_hijack.model_hijack.hijack(sd_model) + # sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) diff --git a/modules/textual_inversion/ui.py b/modules/textual_inversion/ui.py index f149ad1f..0457d8af 100644 --- a/modules/textual_inversion/ui.py +++ b/modules/textual_inversion/ui.py @@ -1,38 +1,38 @@ -import html - -import gradio as gr - -import modules.textual_inversion.textual_inversion -from modules import sd_hijack, shared - - -def create_embedding(name, initialization_text, nvpt, overwrite_old): - filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) - - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - - return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" - - -def train_embedding(*args): - - assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' - - apply_optimizations = shared.opts.training_xattention_optimizations - try: - if not apply_optimizations: - sd_hijack.undo_optimizations() - - embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) - - res = f""" -Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. -Embedding saved to {html.escape(filename)} -""" - return res, "" - except Exception: - raise - finally: - if not apply_optimizations: - sd_hijack.apply_optimizations() - +# import html +# +# import gradio as gr +# +# import modules.textual_inversion.textual_inversion +# from modules import sd_hijack, shared +# +# +# def create_embedding(name, initialization_text, nvpt, overwrite_old): +# filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) +# +# sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() +# +# return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" +# +# +# def train_embedding(*args): +# +# assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' +# +# apply_optimizations = shared.opts.training_xattention_optimizations +# try: +# if not apply_optimizations: +# sd_hijack.undo_optimizations() +# +# embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) +# +# res = f""" +# Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. +# Embedding saved to {html.escape(filename)} +# """ +# return res, "" +# except Exception: +# raise +# finally: +# if not apply_optimizations: +# sd_hijack.apply_optimizations() +# diff --git a/modules/ui.py b/modules/ui.py index 341cfa7c..088f36af 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -24,7 +24,6 @@ from modules.shared import opts, cmd_opts import modules.infotext_utils as parameters_copypaste import modules.shared as shared from modules import prompt_parser -from modules.sd_hijack import model_hijack from modules.infotext_utils import image_from_url_text, PasteField from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head