mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-07 22:19:49 +00:00
add a way to save models
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
import torch
|
||||
import safetensors.torch as sf
|
||||
|
||||
from backend import utils
|
||||
|
||||
|
||||
class ForgeObjects:
|
||||
@@ -63,3 +66,22 @@ class ForgeDiffusionEngine:
|
||||
self.is_sdxl = False
|
||||
self.is_sd3 = False
|
||||
return
|
||||
|
||||
def save_unet(self, filename):
|
||||
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
@@ -138,3 +138,16 @@ def nested_move_to_device(obj, device):
|
||||
elif isinstance(obj, torch.Tensor):
|
||||
return obj.to(device)
|
||||
return obj
|
||||
|
||||
|
||||
def get_state_dict_after_quant(model, prefix=''):
|
||||
for m in model.modules():
|
||||
if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized'):
|
||||
if not m.weight.bnb_quantized:
|
||||
original_device = m.weight.device
|
||||
m.cuda()
|
||||
m.to(original_device)
|
||||
|
||||
sd = model.state_dict()
|
||||
sd = {(prefix + k): v.clone() for k, v in sd.items()}
|
||||
return sd
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
|
||||
import os
|
||||
import gradio as gr
|
||||
|
||||
from modules import sd_models, sd_vae, errors, extras, call_queue
|
||||
@@ -29,6 +29,38 @@ def modelmerger(*args):
|
||||
class UiCheckpointMerger:
|
||||
def __init__(self):
|
||||
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
|
||||
with gr.Accordion(open=True, label='Save Current Checkpoint (including all quantization)'):
|
||||
with gr.Row():
|
||||
textbox_file_name_forge = gr.Textbox(label="Filename (will save in /models/Stable-diffusion)", value='my_model.safetensors')
|
||||
btn_save_unet_forge = gr.Button('Save UNet')
|
||||
btn_save_ckpt_forge = gr.Button('Save Checkpoint')
|
||||
|
||||
with gr.Row():
|
||||
result_html = gr.HTML('Ready to save ... (Currently only support saving Flux models)')
|
||||
|
||||
def save_unet(filename):
|
||||
from modules.paths import models_path
|
||||
long_filename = os.path.join(models_path, 'Stable-diffusion', filename)
|
||||
os.makedirs(os.path.dirname(long_filename), exist_ok=True)
|
||||
from modules import shared, sd_models
|
||||
sd_models.forge_model_reload()
|
||||
p = shared.sd_model.save_unet(long_filename)
|
||||
print(f'Saved UNet at: {p}')
|
||||
return f'Saved UNet at: {p}'
|
||||
|
||||
def save_checkpoint(filename):
|
||||
from modules.paths import models_path
|
||||
long_filename = os.path.join(models_path, 'Stable-diffusion', filename)
|
||||
os.makedirs(os.path.dirname(long_filename), exist_ok=True)
|
||||
from modules import shared
|
||||
sd_models.forge_model_reload()
|
||||
p = shared.sd_model.save_checkpoint(long_filename)
|
||||
print(f'Saved checkpoint at: {p}')
|
||||
return f'Saved checkpoint at: {p}'
|
||||
|
||||
btn_save_unet_forge.click(save_unet, inputs=textbox_file_name_forge, outputs=result_html)
|
||||
btn_save_ckpt_forge.click(save_checkpoint, inputs=textbox_file_name_forge, outputs=result_html)
|
||||
|
||||
with gr.Row(equal_height=False):
|
||||
with gr.Column(variant='compact'):
|
||||
self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
|
||||
|
||||
Reference in New Issue
Block a user