mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-28 10:11:42 +00:00
Finally removed model_hijack
finally
This commit is contained in:
@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
|
|||||||
from secrets import compare_digest
|
from secrets import compare_digest
|
||||||
|
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
|
from modules import sd_samplers, deepbooru, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, infotext_utils, sd_models, sd_schedulers
|
||||||
from modules.api import models
|
from modules.api import models
|
||||||
from modules.shared import opts
|
from modules.shared import opts
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ def configure_sigint_handler():
|
|||||||
|
|
||||||
|
|
||||||
def configure_opts_onchange():
|
def configure_opts_onchange():
|
||||||
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
|
from modules import shared, sd_models, sd_vae, ui_tempdir
|
||||||
from modules.call_queue import wrap_queued_call
|
from modules.call_queue import wrap_queued_call
|
||||||
from modules_forge import main_thread
|
from modules_forge import main_thread
|
||||||
|
|
||||||
@@ -186,7 +186,7 @@ def configure_opts_onchange():
|
|||||||
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: main_thread.run_and_wait_result(sd_vae.reload_vae_weights)), call=False)
|
||||||
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
|
||||||
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
|
||||||
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
# shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
|
||||||
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
|
||||||
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
shared.opts.onchange("cache_fp16_weight", wrap_queued_call(lambda: sd_models.reload_model_weights(forced_reload=True)), call=False)
|
||||||
startup_timer.record("opts onchange")
|
startup_timer.record("opts onchange")
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ from typing import Any
|
|||||||
import modules.sd_hijack
|
import modules.sd_hijack
|
||||||
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
|
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, infotext_utils, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng, profiling
|
||||||
from modules.rng import slerp # noqa: F401
|
from modules.rng import slerp # noqa: F401
|
||||||
from modules.sd_hijack import model_hijack
|
|
||||||
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
|
||||||
from modules.shared import opts, cmd_opts, state
|
from modules.shared import opts, cmd_opts, state
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
@@ -451,7 +450,7 @@ class StableDiffusionProcessing:
|
|||||||
for cache in caches:
|
for cache in caches:
|
||||||
if cache[0] is not None and cached_params == cache[0]:
|
if cache[0] is not None and cached_params == cache[0]:
|
||||||
if len(cache) > 2:
|
if len(cache) > 2:
|
||||||
modules.sd_hijack.model_hijack.extra_generation_params.update(cache[2])
|
shared.sd_model.extra_generation_params.update(cache[2])
|
||||||
return cache[1]
|
return cache[1]
|
||||||
|
|
||||||
cache = caches[0]
|
cache = caches[0]
|
||||||
@@ -465,7 +464,7 @@ class StableDiffusionProcessing:
|
|||||||
|
|
||||||
last_extra_generation_params = backend.text_processing.classic_engine.last_extra_generation_params.copy()
|
last_extra_generation_params = backend.text_processing.classic_engine.last_extra_generation_params.copy()
|
||||||
|
|
||||||
modules.sd_hijack.model_hijack.extra_generation_params.update(last_extra_generation_params)
|
shared.sd_model.extra_generation_params.update(last_extra_generation_params)
|
||||||
|
|
||||||
if len(cache) > 2:
|
if len(cache) > 2:
|
||||||
cache[2] = last_extra_generation_params
|
cache[2] = last_extra_generation_params
|
||||||
@@ -835,7 +834,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
|
p.sd_vae_hash = sd_vae.get_loaded_vae_hash()
|
||||||
|
|
||||||
apply_circular_forge(p.sd_model, p.tiling)
|
apply_circular_forge(p.sd_model, p.tiling)
|
||||||
modules.sd_hijack.model_hijack.clear_comments()
|
p.sd_model.comments = []
|
||||||
|
p.sd_model.extra_generation_params = {}
|
||||||
|
|
||||||
p.fill_fields_from_opts()
|
p.fill_fields_from_opts()
|
||||||
p.setup_prompts()
|
p.setup_prompts()
|
||||||
@@ -911,7 +911,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
|
|
||||||
p.setup_conds()
|
p.setup_conds()
|
||||||
|
|
||||||
p.extra_generation_params.update(model_hijack.extra_generation_params)
|
p.extra_generation_params.update(p.sd_model.extra_generation_params)
|
||||||
|
|
||||||
# params.txt should be saved after scripts.process_batch, since the
|
# params.txt should be saved after scripts.process_batch, since the
|
||||||
# infotext could be modified by that callback
|
# infotext could be modified by that callback
|
||||||
@@ -922,7 +922,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
|
|||||||
processed = Processed(p, [])
|
processed = Processed(p, [])
|
||||||
file.write(processed.infotext(p, 0))
|
file.write(processed.infotext(p, 0))
|
||||||
|
|
||||||
for comment in model_hijack.comments:
|
for comment in p.sd_model.comments:
|
||||||
p.comment(comment)
|
p.comment(comment)
|
||||||
|
|
||||||
if p.n_iter > 1:
|
if p.n_iter > 1:
|
||||||
@@ -1628,7 +1628,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
|
|||||||
self.mask_for_overlay = None
|
self.mask_for_overlay = None
|
||||||
self.inpaint_full_res = False
|
self.inpaint_full_res = False
|
||||||
massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
|
massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.'
|
||||||
model_hijack.comments.append(massage)
|
self.sd_model.comments.append(massage)
|
||||||
logging.info(massage)
|
logging.info(massage)
|
||||||
else:
|
else:
|
||||||
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
|
||||||
|
|||||||
@@ -1,41 +1,41 @@
|
|||||||
class StableDiffusionModelHijack:
|
# class StableDiffusionModelHijack:
|
||||||
fixes = None
|
# fixes = None
|
||||||
layers = None
|
# layers = None
|
||||||
circular_enabled = False
|
# circular_enabled = False
|
||||||
clip = None
|
# clip = None
|
||||||
optimization_method = None
|
# optimization_method = None
|
||||||
|
#
|
||||||
def __init__(self):
|
# def __init__(self):
|
||||||
self.extra_generation_params = {}
|
# self.extra_generation_params = {}
|
||||||
self.comments = []
|
# self.comments = []
|
||||||
|
#
|
||||||
def apply_optimizations(self, option=None):
|
# def apply_optimizations(self, option=None):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def convert_sdxl_to_ssd(self, m):
|
# def convert_sdxl_to_ssd(self, m):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def hijack(self, m):
|
# def hijack(self, m):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def undo_hijack(self, m):
|
# def undo_hijack(self, m):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def apply_circular(self, enable):
|
# def apply_circular(self, enable):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def clear_comments(self):
|
# def clear_comments(self):
|
||||||
self.comments = []
|
# self.comments = []
|
||||||
self.extra_generation_params = {}
|
# self.extra_generation_params = {}
|
||||||
|
#
|
||||||
def get_prompt_lengths(self, text, cond_stage_model):
|
# def get_prompt_lengths(self, text, cond_stage_model):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
def redo_hijack(self, m):
|
# def redo_hijack(self, m):
|
||||||
pass
|
# pass
|
||||||
|
#
|
||||||
|
#
|
||||||
model_hijack = StableDiffusionModelHijack()
|
# model_hijack = StableDiffusionModelHijack()
|
||||||
|
|
||||||
# import torch
|
# import torch
|
||||||
# from torch.nn.functional import silu
|
# from torch.nn.functional import silu
|
||||||
|
|||||||
@@ -562,6 +562,8 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
|
|||||||
sd_model = forge_loader(state_dict)
|
sd_model = forge_loader(state_dict)
|
||||||
timer.record("forge model load")
|
timer.record("forge model load")
|
||||||
|
|
||||||
|
sd_model.extra_generation_params = {}
|
||||||
|
sd_model.comments = []
|
||||||
sd_model.sd_checkpoint_info = checkpoint_info
|
sd_model.sd_checkpoint_info = checkpoint_info
|
||||||
sd_model.filename = checkpoint_info.filename
|
sd_model.filename = checkpoint_info.filename
|
||||||
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
sd_model.sd_model_hash = checkpoint_info.calculate_shorthash()
|
||||||
|
|||||||
@@ -262,11 +262,11 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
|
|||||||
if loaded_vae_file == vae_file:
|
if loaded_vae_file == vae_file:
|
||||||
return
|
return
|
||||||
|
|
||||||
sd_hijack.model_hijack.undo_hijack(sd_model)
|
# sd_hijack.model_hijack.undo_hijack(sd_model)
|
||||||
|
|
||||||
load_vae(sd_model, vae_file, vae_source)
|
load_vae(sd_model, vae_file, vae_source)
|
||||||
|
|
||||||
sd_hijack.model_hijack.hijack(sd_model)
|
# sd_hijack.model_hijack.hijack(sd_model)
|
||||||
|
|
||||||
script_callbacks.model_loaded_callback(sd_model)
|
script_callbacks.model_loaded_callback(sd_model)
|
||||||
|
|
||||||
|
|||||||
@@ -1,38 +1,38 @@
|
|||||||
import html
|
# import html
|
||||||
|
#
|
||||||
import gradio as gr
|
# import gradio as gr
|
||||||
|
#
|
||||||
import modules.textual_inversion.textual_inversion
|
# import modules.textual_inversion.textual_inversion
|
||||||
from modules import sd_hijack, shared
|
# from modules import sd_hijack, shared
|
||||||
|
#
|
||||||
|
#
|
||||||
def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
# def create_embedding(name, initialization_text, nvpt, overwrite_old):
|
||||||
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
|
# filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
|
||||||
|
#
|
||||||
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
# sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
||||||
|
#
|
||||||
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
# return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
||||||
|
#
|
||||||
|
#
|
||||||
def train_embedding(*args):
|
# def train_embedding(*args):
|
||||||
|
#
|
||||||
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
# assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
|
||||||
|
#
|
||||||
apply_optimizations = shared.opts.training_xattention_optimizations
|
# apply_optimizations = shared.opts.training_xattention_optimizations
|
||||||
try:
|
# try:
|
||||||
if not apply_optimizations:
|
# if not apply_optimizations:
|
||||||
sd_hijack.undo_optimizations()
|
# sd_hijack.undo_optimizations()
|
||||||
|
#
|
||||||
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
# embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
||||||
|
#
|
||||||
res = f"""
|
# res = f"""
|
||||||
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
|
# Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
|
||||||
Embedding saved to {html.escape(filename)}
|
# Embedding saved to {html.escape(filename)}
|
||||||
"""
|
# """
|
||||||
return res, ""
|
# return res, ""
|
||||||
except Exception:
|
# except Exception:
|
||||||
raise
|
# raise
|
||||||
finally:
|
# finally:
|
||||||
if not apply_optimizations:
|
# if not apply_optimizations:
|
||||||
sd_hijack.apply_optimizations()
|
# sd_hijack.apply_optimizations()
|
||||||
|
#
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ from modules.shared import opts, cmd_opts
|
|||||||
import modules.infotext_utils as parameters_copypaste
|
import modules.infotext_utils as parameters_copypaste
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
from modules import prompt_parser
|
from modules import prompt_parser
|
||||||
from modules.sd_hijack import model_hijack
|
|
||||||
from modules.infotext_utils import image_from_url_text, PasteField
|
from modules.infotext_utils import image_from_url_text, PasteField
|
||||||
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head
|
from modules_forge.forge_canvas.canvas import ForgeCanvas, canvas_head
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user