From 96f264ec6a8829cac9a12c9a75b5e8fa1e363a45 Mon Sep 17 00:00:00 2001
From: layerdiffusion <19834515+lllyasviel@users.noreply.github.com>
Date: Mon, 19 Aug 2024 06:30:49 -0700
Subject: [PATCH] add a way to save models

---
Notes (ignored by `git am`; not part of the commit message):
  * backend/diffusion_engine/base.py: ForgeDiffusionEngine gains save_unet()
    and save_checkpoint(). Both build a state dict with
    utils.get_state_dict_after_quant() and write it with
    safetensors.torch.save_file(), returning the filename they were given.
    save_checkpoint() merges the UNet, text-encoder and VAE state dicts under
    the key prefixes 'model.diffusion_model.', 'text_encoders.' and 'vae.'.
  * backend/utils.py: get_state_dict_after_quant() returns a state dict whose
    keys are prefixed and whose tensors are cloned. Before reading the state
    dict, every submodule whose weight has a falsy `bnb_quantized` attribute
    is moved to CUDA and back to its original device.
  * modules/ui_checkpoint_merger.py: adds a "Save Current Checkpoint"
    accordion with "Save UNet" / "Save Checkpoint" buttons that write below
    <models_path>/Stable-diffusion after calling sd_models.forge_model_reload().

 backend/diffusion_engine/base.py | 22 +++++++++++++++++++++
 backend/utils.py                 | 13 ++++++++++++
 modules/ui_checkpoint_merger.py  | 34 +++++++++++++++++++++++++++++++-
 3 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/backend/diffusion_engine/base.py b/backend/diffusion_engine/base.py
index 8ec9509b..89a6055e 100644
--- a/backend/diffusion_engine/base.py
+++ b/backend/diffusion_engine/base.py
@@ -1,4 +1,7 @@
 import torch
+import safetensors.torch as sf
+
+from backend import utils
 
 
 class ForgeObjects:
@@ -63,3 +66,22 @@ class ForgeDiffusionEngine:
         self.is_sdxl = False
         self.is_sd3 = False
         return
+
+    def save_unet(self, filename):
+        sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
+        sf.save_file(sd, filename)
+        return filename
+
+    def save_checkpoint(self, filename):
+        sd = {}
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
+        )
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
+        )
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
+        )
+        sf.save_file(sd, filename)
+        return filename
diff --git a/backend/utils.py b/backend/utils.py
index 2d15a76a..56535e0c 100644
--- a/backend/utils.py
+++ b/backend/utils.py
@@ -138,3 +138,16 @@ def nested_move_to_device(obj, device):
     elif isinstance(obj, torch.Tensor):
         return obj.to(device)
     return obj
+
+
+def get_state_dict_after_quant(model, prefix=''):
+    for m in model.modules():
+        if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized'):
+            if not m.weight.bnb_quantized:
+                original_device = m.weight.device
+                m.cuda()
+                m.to(original_device)
+
+    sd = model.state_dict()
+    sd = {(prefix + k): v.clone() for k, v in sd.items()}
+    return sd
diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py
index f9c5dd6b..f6de282a 100644
--- a/modules/ui_checkpoint_merger.py
+++ b/modules/ui_checkpoint_merger.py
@@ -1,4 +1,4 @@
-
+import os
 import gradio as gr
 
 from modules import sd_models, sd_vae, errors, extras, call_queue
@@ -29,6 +29,38 @@ def modelmerger(*args):
 class UiCheckpointMerger:
     def __init__(self):
         with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+            with gr.Accordion(open=True, label='Save Current Checkpoint (including all quantization)'):
+                with gr.Row():
+                    textbox_file_name_forge = gr.Textbox(label="Filename (will save in /models/Stable-diffusion)", value='my_model.safetensors')
+                    btn_save_unet_forge = gr.Button('Save UNet')
+                    btn_save_ckpt_forge = gr.Button('Save Checkpoint')
+
+                with gr.Row():
+                    result_html = gr.HTML('Ready to save ... (Currently only support saving Flux models)')
+
+                def save_unet(filename):
+                    from modules.paths import models_path
+                    long_filename = os.path.join(models_path, 'Stable-diffusion', filename)
+                    os.makedirs(os.path.dirname(long_filename), exist_ok=True)
+                    from modules import shared, sd_models
+                    sd_models.forge_model_reload()
+                    p = shared.sd_model.save_unet(long_filename)
+                    print(f'Saved UNet at: {p}')
+                    return f'Saved UNet at: {p}'
+
+                def save_checkpoint(filename):
+                    from modules.paths import models_path
+                    long_filename = os.path.join(models_path, 'Stable-diffusion', filename)
+                    os.makedirs(os.path.dirname(long_filename), exist_ok=True)
+                    from modules import shared
+                    sd_models.forge_model_reload()
+                    p = shared.sd_model.save_checkpoint(long_filename)
+                    print(f'Saved checkpoint at: {p}')
+                    return f'Saved checkpoint at: {p}'
+
+                btn_save_unet_forge.click(save_unet, inputs=textbox_file_name_forge, outputs=result_html)
+                btn_save_ckpt_forge.click(save_checkpoint, inputs=textbox_file_name_forge, outputs=result_html)
+
             with gr.Row(equal_height=False):
                 with gr.Column(variant='compact'):
                     self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")