diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index a0809852..363199f7 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -14,6 +14,7 @@ import gradio as gr from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_schedulers, errors from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, state +from modules.sd_models import model_data, select_checkpoint import modules.shared as shared import modules.sd_samplers import modules.sd_models @@ -77,7 +78,29 @@ def apply_checkpoint(p, x, xs): info = modules.sd_models.get_closet_checkpoint_match(x) if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") + # skip if the checkpoint was last override + if info.name == p.override_settings.get('sd_model_checkpoint', None): + return + org_cp = getattr(opts, 'sd_model_checkpoint', None) p.override_settings['sd_model_checkpoint'] = info.name + opts.set('sd_model_checkpoint', info.name) + refresh_loading_params_for_xyz_grid() + # This saves part of the reload + opts.set('sd_model_checkpoint', org_cp) + +def refresh_loading_params_for_xyz_grid(): + """ + Refreshes the loading parameters for the model, + prompts a reload in sd_models.forge_model_reload() + """ + checkpoint_info = select_checkpoint() + + model_data.forge_loading_parameters = dict( + checkpoint_info=checkpoint_info, + additional_modules=shared.opts.forge_additional_modules, + #unet_storage_dtype=shared.opts.forge_unet_storage_dtype + unet_storage_dtype=model_data.forge_loading_parameters.get('unet_storage_dtype', None) + ) def confirm_checkpoints(p, xs): @@ -783,6 +806,9 @@ class Script(scripts.Script): second_axes_processed=second_axes_processed, margin_size=margin_size ) + + # reset loading params to previous state + refresh_loading_params_for_xyz_grid() if not processed.images: # It broke, no further handling needed.