mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-03 12:09:51 +00:00
update for Save Checkpoint button (#2636)
function in Checkpoint Merger previously produced unusable checkpoints for non-Flux architectures because keys had names not recognised by the model loader
This commit is contained in:
@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
|
||||
class StableDiffusion(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD15]
|
||||
@@ -79,3 +82,19 @@ class StableDiffusion(ForgeDiffusionEngine):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SD15.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
||||
from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
|
||||
class StableDiffusion2(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SD20]
|
||||
@@ -79,3 +82,19 @@ class StableDiffusion2(ForgeDiffusionEngine):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SD20.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
@@ -10,6 +10,9 @@ from backend.args import dynamic_args
|
||||
from backend import memory_management
|
||||
from backend.nn.unet import Timestep
|
||||
|
||||
import safetensors.torch as sf
|
||||
from backend import utils
|
||||
|
||||
|
||||
class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
matched_guesses = [model_list.SDXL]
|
||||
@@ -131,3 +134,19 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||
return sample.to(x)
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
model_list.SDXL.process_clip_state_dict_for_saving(self,
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||
)
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
Reference in New Issue
Block a user