mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-30 11:11:15 +00:00
update for Save Checkpoint button (#2636)
function in Checkpoint Merger previously produced unusable checkpoints for non-Flux architectures because keys had names not recognised by the model loader
This commit is contained in:
@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
|||||||
from backend.args import dynamic_args
|
from backend.args import dynamic_args
|
||||||
from backend import memory_management
|
from backend import memory_management
|
||||||
|
|
||||||
|
import safetensors.torch as sf
|
||||||
|
from backend import utils
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion(ForgeDiffusionEngine):
|
class StableDiffusion(ForgeDiffusionEngine):
|
||||||
matched_guesses = [model_list.SD15]
|
matched_guesses = [model_list.SD15]
|
||||||
@@ -79,3 +82,19 @@ class StableDiffusion(ForgeDiffusionEngine):
|
|||||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||||
return sample.to(x)
|
return sample.to(x)
|
||||||
|
|
||||||
|
def save_checkpoint(self, filename):
|
||||||
|
sd = {}
|
||||||
|
sd.update(
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||||
|
)
|
||||||
|
sd.update(
|
||||||
|
model_list.SD15.process_clip_state_dict_for_saving(self,
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sd.update(
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||||
|
)
|
||||||
|
sf.save_file(sd, filename)
|
||||||
|
return filename
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from backend.text_processing.classic_engine import ClassicTextProcessingEngine
|
|||||||
from backend.args import dynamic_args
|
from backend.args import dynamic_args
|
||||||
from backend import memory_management
|
from backend import memory_management
|
||||||
|
|
||||||
|
import safetensors.torch as sf
|
||||||
|
from backend import utils
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion2(ForgeDiffusionEngine):
|
class StableDiffusion2(ForgeDiffusionEngine):
|
||||||
matched_guesses = [model_list.SD20]
|
matched_guesses = [model_list.SD20]
|
||||||
@@ -79,3 +82,19 @@ class StableDiffusion2(ForgeDiffusionEngine):
|
|||||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||||
return sample.to(x)
|
return sample.to(x)
|
||||||
|
|
||||||
|
def save_checkpoint(self, filename):
|
||||||
|
sd = {}
|
||||||
|
sd.update(
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||||
|
)
|
||||||
|
sd.update(
|
||||||
|
model_list.SD20.process_clip_state_dict_for_saving(self,
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sd.update(
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||||
|
)
|
||||||
|
sf.save_file(sd, filename)
|
||||||
|
return filename
|
||||||
|
|||||||
@@ -10,6 +10,9 @@ from backend.args import dynamic_args
|
|||||||
from backend import memory_management
|
from backend import memory_management
|
||||||
from backend.nn.unet import Timestep
|
from backend.nn.unet import Timestep
|
||||||
|
|
||||||
|
import safetensors.torch as sf
|
||||||
|
from backend import utils
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionXL(ForgeDiffusionEngine):
|
class StableDiffusionXL(ForgeDiffusionEngine):
|
||||||
matched_guesses = [model_list.SDXL]
|
matched_guesses = [model_list.SDXL]
|
||||||
@@ -131,3 +134,19 @@ class StableDiffusionXL(ForgeDiffusionEngine):
|
|||||||
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
sample = self.forge_objects.vae.first_stage_model.process_out(x)
|
||||||
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
|
||||||
return sample.to(x)
|
return sample.to(x)
|
||||||
|
|
||||||
|
def save_checkpoint(self, filename):
|
||||||
|
sd = {}
|
||||||
|
sd.update(
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||||
|
)
|
||||||
|
sd.update(
|
||||||
|
model_list.SDXL.process_clip_state_dict_for_saving(self,
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sd.update(
|
||||||
|
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='first_stage_model.')
|
||||||
|
)
|
||||||
|
sf.save_file(sd, filename)
|
||||||
|
return filename
|
||||||
|
|||||||
Reference in New Issue
Block a user