From 4a30c157691a33c2cc6e8d4fe861907429428f1e Mon Sep 17 00:00:00 2001
From: DenOfEquity <166248528+DenOfEquity@users.noreply.github.com>
Date: Sun, 9 Feb 2025 13:21:15 +0000
Subject: [PATCH] update for Save Checkpoint button (#2636)

function in Checkpoint Merger previously produced unusable checkpoints for non-Flux architectures because keys had names not recognised by the model loader
---
Notes (reviewer annotation; this text sits between "---" and the first
"diff --git" line and is not part of the commit message):

  * Each of the three engine classes touched here (StableDiffusion,
    StableDiffusion2, StableDiffusionXL) gains a
    save_checkpoint(self, filename) method, plus the two imports it needs
    (safetensors.torch as sf, backend.utils).
  * The method fills a single dict from three calls to
    utils.get_state_dict_after_quant:
      - forge_objects.unet.model.diffusion_model,
        prefix 'model.diffusion_model.'
      - forge_objects.clip.cond_stage_model, prefix '' -- the result is
        passed through
        model_list.<SD15|SD20|SDXL>.process_clip_state_dict_for_saving
        before being merged
      - forge_objects.vae.first_stage_model, prefix 'first_stage_model.'
  * The dict is written with sf.save_file(sd, filename) and the method
    returns filename unchanged.
  * NOTE(review): process_clip_state_dict_for_saving is called on the
    model_list class with the engine instance passed explicitly as its
    first argument; presumably that method never reads instance state --
    confirm against the model_list implementation.
  * NOTE(review): line breaks and indentation below were restored from a
    whitespace-collapsed copy of this patch; the blank context lines are
    inferred from the hunk line counts. Run `git apply --check` before
    applying.

 backend/diffusion_engine/sd15.py | 19 +++++++++++++++++++
 backend/diffusion_engine/sd20.py | 19 +++++++++++++++++++
 backend/diffusion_engine/sdxl.py | 19 +++++++++++++++++++
 3 files changed, 57 insertions(+)

diff --git a/backend/diffusion_engine/sd15.py b/backend/diffusion_engine/sd15.py
index af47eb53..8e446475 100644
--- a/backend/diffusion_engine/sd15.py
+++ b/backend/diffusion_engine/sd15.py
@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
 from backend.args import dynamic_args
 from backend import memory_management
 
+import safetensors.torch as sf
+from backend import utils
+
 
 class StableDiffusion(ForgeDiffusionEngine):
     matched_guesses = [model_list.SD15]
@@ -79,3 +82,19 @@ class StableDiffusion(ForgeDiffusionEngine):
         sample = self.forge_objects.vae.first_stage_model.process_out(x)
         sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
         return sample.to(x)
+
+    def save_checkpoint(self, filename):
+        sd = {}
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
+        )
+        sd.update(
+            model_list.SD15.process_clip_state_dict_for_saving(self,
+                utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
+            )
+        )
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
+        )
+        sf.save_file(sd, filename)
+        return filename
diff --git a/backend/diffusion_engine/sd20.py b/backend/diffusion_engine/sd20.py
index adb69528..42594f72 100644
--- a/backend/diffusion_engine/sd20.py
+++ b/backend/diffusion_engine/sd20.py
@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
 from backend.args import dynamic_args
 from backend import memory_management
 
+import safetensors.torch as sf
+from backend import utils
+
 
 class StableDiffusion2(ForgeDiffusionEngine):
     matched_guesses = [model_list.SD20]
@@ -79,3 +82,19 @@ class StableDiffusion2(ForgeDiffusionEngine):
         sample = self.forge_objects.vae.first_stage_model.process_out(x)
         sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
         return sample.to(x)
+
+    def save_checkpoint(self, filename):
+        sd = {}
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
+        )
+        sd.update(
+            model_list.SD20.process_clip_state_dict_for_saving(self,
+                utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
+            )
+        )
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
+        )
+        sf.save_file(sd, filename)
+        return filename
diff --git a/backend/diffusion_engine/sdxl.py b/backend/diffusion_engine/sdxl.py
index 0873da18..d12fe195 100644
--- a/backend/diffusion_engine/sdxl.py
+++ b/backend/diffusion_engine/sdxl.py
@@ -10,6 +10,9 @@ from backend.args import dynamic_args
 from backend import memory_management
 from backend.nn.unet import Timestep
 
+import safetensors.torch as sf
+from backend import utils
+
 
 class StableDiffusionXL(ForgeDiffusionEngine):
     matched_guesses = [model_list.SDXL]
@@ -131,3 +134,19 @@ class StableDiffusionXL(ForgeDiffusionEngine):
         sample = self.forge_objects.vae.first_stage_model.process_out(x)
         sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
         return sample.to(x)
+
+    def save_checkpoint(self, filename):
+        sd = {}
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
+        )
+        sd.update(
+            model_list.SDXL.process_clip_state_dict_for_saving(self,
+                utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
+            )
+        )
+        sd.update(
+            utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
+        )
+        sf.save_file(sd, filename)
+        return filename