mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-03-14 01:19:49 +00:00
add a way to save models
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
import torch
|
||||
import safetensors.torch as sf
|
||||
|
||||
from backend import utils
|
||||
|
||||
|
||||
class ForgeObjects:
|
||||
@@ -63,3 +66,22 @@ class ForgeDiffusionEngine:
|
||||
self.is_sdxl = False
|
||||
self.is_sd3 = False
|
||||
return
|
||||
|
||||
def save_unet(self, filename):
|
||||
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
Reference in New Issue
Block a user