Finally removed model_hijack

finally
This commit is contained in:
layerdiffusion
2024-08-05 21:05:25 -07:00
parent 252d437f5d
commit ae1d995d0d
8 changed files with 90 additions and 89 deletions

View File

@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest
import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images

View File

@@ -177,7 +177,7 @@ def configure_sigint_handler():
def configure_opts_onchange():
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
from modules import shared, sd_models, sd_vae, ui_tempdir
from modules.call_queue import wrap_queued_call
from modules_forge import main_thread
@@ -186,7 +186,7 @@ def configure_opts_onchange():
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
# shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
startup_timer.record("opts onchange")

View File

@@ -18,7 +18,6 @@ from typing import Any
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -451,7 +450,7 @@ class StableDiffusionProcessing:
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
if len(cache) > 2:
modules.sd_hijack.model_hijack.extra_generation_params.update(cache[2])
shared.sd_model.extra_generation_params.update(cache[2])
return cache[1]
cache = caches[0]
@@ -465,7 +464,7 @@ class StableDiffusionProcessing:
last_extra_generation_params = backend.text_processing.classic_engine.last_extra_generation_params.copy()
modules.sd_hijack.model_hijack.extra_generation_params.update(last_extra_generation_params)
shared.sd_model.extra_generation_params.update(last_extra_generation_params)
if len(cache) > 2:
cache[2] = last_extra_generation_params
@@ -835,7 +834,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
apply_circular_forge(p.sd_model, p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
p.sd_model.comments = []
p.sd_model.extra_generation_params = {}
p.fill_fields_from_opts()
p.setup_prompts()
@@ -911,7 +911,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.setup_conds()
p.extra_generation_params.update(model_hijack.extra_generation_params)
p.extra_generation_params.update(p.sd_model.extra_generation_params)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
@@ -922,7 +922,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [])
file.write(processed.infotext(p, 0))
for comment in model_hijack.comments:
for comment in p.sd_model.comments:
p.comment(comment)
if p.n_iter > 1:
@@ -1628,7 +1628,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.mask_for_overlay = None
self.inpaint_full_res = False
massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
model_hijack.comments.append(massage)
self.sd_model.comments.append(massage)
logging.info(massage)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)

View File

@@ -1,41 +1,41 @@
class StableDiffusionModelHijack:
fixes = None
layers = None
circular_enabled = False
clip = None
optimization_method = None
def __init__(self):
self.extra_generation_params = {}
self.comments = []
def apply_optimizations(self, option=None):
pass
def convert_sdxl_to_ssd(self, m):
pass
def hijack(self, m):
pass
def undo_hijack(self, m):
pass
def apply_circular(self, enable):
pass
def clear_comments(self):
self.comments = []
self.extra_generation_params = {}
def get_prompt_lengths(self, text, cond_stage_model):
pass
def redo_hijack(self, m):
pass
model_hijack = StableDiffusionModelHijack()
# class StableDiffusionModelHijack:
# fixes = None
# layers = None
# circular_enabled = False
# clip = None
# optimization_method = None
#
# def __init__(self):
# self.extra_generation_params = {}
# self.comments = []
#
# def apply_optimizations(self, option=None):
# pass
#
# def convert_sdxl_to_ssd(self, m):
# pass
#
# def hijack(self, m):
# pass
#
# def undo_hijack(self, m):
# pass
#
# def apply_circular(self, enable):
# pass
#
# def clear_comments(self):
# self.comments = []
# self.extra_generation_params = {}
#
# def get_prompt_lengths(self, text, cond_stage_model):
# pass
#
# def redo_hijack(self, m):
# pass
#
#
# model_hijack = StableDiffusionModelHijack()
# import torch
# from torch.nn.functional import silu

View File

@@ -562,6 +562,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = forge_loader(state_dict)
timer.record("forge model load")
sd_model.extra_generation_params = {}
sd_model.comments = []
sd_model.sd_checkpoint_info = checkpoint_info
sd_model.filename = checkpoint_info.filename
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()

View File

@@ -262,11 +262,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
if loaded_vae_file == vae_file:
return
sd_hijack.model_hijack.undo_hijack(sd_model)
# sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file, vae_source)
sd_hijack.model_hijack.hijack(sd_model)
# sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)

View File

@@ -1,38 +1,38 @@
import html
import gradio as gr
import modules.textual_inversion.textual_inversion
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt, overwrite_old):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try:
if not apply_optimizations:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
if not apply_optimizations:
sd_hijack.apply_optimizations()
# import html
#
# import gradio as gr
#
# import modules.textual_inversion.textual_inversion
# from modules import sd_hijack, shared
#
#
# def create_embedding(name, initialization_text, nvpt, overwrite_old):
# filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
#
# sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
#
# return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
#
#
# def train_embedding(*args):
#
# assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
#
# apply_optimizations = shared.opts.training_xattention_optimizations
# try:
# if not apply_optimizations:
# sd_hijack.undo_optimizations()
#
# embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
#
# res = f"""
# Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
# Embedding saved to {html.escape(filename)}
# """
# return res, ""
# except Exception:
# raise
# finally:
# if not apply_optimizations:
# sd_hijack.apply_optimizations()
#

View File

@@ -24,7 +24,6 @@ from modules.shared import opts, cmd_opts
import modules.infotext_utils as parameters_copypaste
import modules.shared as shared
from modules import prompt_parser
from modules.sd_hijack import model_hijack
from modules.infotext_utils import image_from_url_text, PasteField
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head