add a way to save models

This commit is contained in:
layerdiffusion
2024-08-19 06:30:49 -07:00
parent 4e8ba14dd0
commit 96f264ec6a
3 changed files with 68 additions and 1 deletions

View File

@@ -1,4 +1,7 @@
import torch
import safetensors.torch as sf
from backend import utils
class ForgeObjects:
@@ -63,3 +66,22 @@ class ForgeDiffusionEngine:
self.is_sdxl = False
self.is_sd3 = False
return
def save_unet(self, filename):
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
sf.save_file(sd, filename)
return filename
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
)
sf.save_file(sd, filename)
return filename

View File

@@ -138,3 +138,16 @@ def nested_move_to_device(obj, device):
elif isinstance(obj, torch.Tensor):
return obj.to(device)
return obj
def get_state_dict_after_quant(model, prefix=''):
for m in model.modules():
if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized'):
if not m.weight.bnb_quantized:
original_device = m.weight.device
m.cuda()
m.to(original_device)
sd = model.state_dict()
sd = {(prefix + k): v.clone() for k, v in sd.items()}
return sd

View File

@@ -1,4 +1,4 @@
import os
import gradio as gr
from modules import sd_models, sd_vae, errors, extras, call_queue
@@ -29,6 +29,38 @@ def modelmerger(*args):
class UiCheckpointMerger:
def __init__(self):
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Accordion(open=True, label='Save Current Checkpoint (including all quantization)'):
with gr.Row():
textbox_file_name_forge = gr.Textbox(label="Filename (will save in /models/Stable-diffusion)", value='my_model.safetensors')
btn_save_unet_forge = gr.Button('Save UNet')
btn_save_ckpt_forge = gr.Button('Save Checkpoint')
with gr.Row():
result_html = gr.HTML('Ready to save ... (Currently only support saving Flux models)')
def save_unet(filename):
from modules.paths import models_path
long_filename = os.path.join(models_path, 'Stable-diffusion', filename)
os.makedirs(os.path.dirname(long_filename), exist_ok=True)
from modules import shared, sd_models
sd_models.forge_model_reload()
p = shared.sd_model.save_unet(long_filename)
print(f'Saved UNet at: {p}')
return f'Saved UNet at: {p}'
def save_checkpoint(filename):
from modules.paths import models_path
long_filename = os.path.join(models_path, 'Stable-diffusion', filename)
os.makedirs(os.path.dirname(long_filename), exist_ok=True)
from modules import shared
sd_models.forge_model_reload()
p = shared.sd_model.save_checkpoint(long_filename)
print(f'Saved checkpoint at: {p}')
return f'Saved checkpoint at: {p}'
btn_save_unet_forge.click(save_unet, inputs=textbox_file_name_forge, outputs=result_html)
btn_save_ckpt_forge.click(save_checkpoint, inputs=textbox_file_name_forge, outputs=result_html)
with gr.Row(equal_height=False):
with gr.Column(variant='compact'):
self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")