update for Save Checkpoint button (#2636)

function in Checkpoint Merger previously produced unusable checkpoints for non-Flux architectures because keys had names not recognised by the model loader
This commit is contained in:
DenOfEquity
2025-02-09 13:21:15 +00:00
committed by GitHub
parent f3672ffbbe
commit 4a30c15769
3 changed files with 57 additions and 0 deletions

View File

@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
import safetensors.torch as sf
from backend import utils
class StableDiffusion(ForgeDiffusionEngine):
matched_guesses = [model_list.SD15]
@@ -79,3 +82,19 @@ class StableDiffusion(ForgeDiffusionEngine):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SD15.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename

View File

@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
from backend.args import dynamic_args
from backend import memory_management
import safetensors.torch as sf
from backend import utils
class StableDiffusion2(ForgeDiffusionEngine):
matched_guesses = [model_list.SD20]
@@ -79,3 +82,19 @@ class StableDiffusion2(ForgeDiffusionEngine):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SD20.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename

View File

@@ -10,6 +10,9 @@ from backend.args import dynamic_args
from backend import memory_management
from backend.nn.unet import Timestep
import safetensors.torch as sf
from backend import utils
class StableDiffusionXL(ForgeDiffusionEngine):
matched_guesses = [model_list.SDXL]
@@ -131,3 +134,19 @@ class StableDiffusionXL(ForgeDiffusionEngine):
sample = self.forge_objects.vae.first_stage_model.process_out(x)
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
return sample.to(x)
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
model_list.SDXL.process_clip_state_dict_for_saving(self,
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
)
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
)
sf.save_file(sd, filename)
return filename