add a way to save models

This commit is contained in:
layerdiffusion
2024-08-19 06:30:49 -07:00
parent 4e8ba14dd0
commit 96f264ec6a
3 changed files with 68 additions and 1 deletions

View File

@@ -1,4 +1,7 @@
import torch
import safetensors.torch as sf
from backend import utils
class ForgeObjects:
@@ -63,3 +66,22 @@ class ForgeDiffusionEngine:
self.is_sdxl = False
self.is_sd3 = False
return
def save_unet(self, filename):
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
sf.save_file(sd, filename)
return filename
def save_checkpoint(self, filename):
sd = {}
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
)
sd.update(
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
)
sf.save_file(sd, filename)
return filename