mirror of
https://github.com/lllyasviel/stable-diffusion-webui-forge.git
synced 2026-04-29 10:41:25 +00:00
add a way to save models
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
import torch
|
||||
import safetensors.torch as sf
|
||||
|
||||
from backend import utils
|
||||
|
||||
|
||||
class ForgeObjects:
|
||||
@@ -63,3 +66,22 @@ class ForgeDiffusionEngine:
|
||||
self.is_sdxl = False
|
||||
self.is_sd3 = False
|
||||
return
|
||||
|
||||
def save_unet(self, filename):
|
||||
sd = utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
def save_checkpoint(self, filename):
|
||||
sd = {}
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.unet.model.diffusion_model, prefix='model.diffusion_model.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.clip.cond_stage_model, prefix='text_encoders.')
|
||||
)
|
||||
sd.update(
|
||||
utils.get_state_dict_after_quant(self.forge_objects.vae.first_stage_model, prefix='vae.')
|
||||
)
|
||||
sf.save_file(sd, filename)
|
||||
return filename
|
||||
|
||||
@@ -138,3 +138,16 @@ def nested_move_to_device(obj, device):
|
||||
elif isinstance(obj, torch.Tensor):
|
||||
return obj.to(device)
|
||||
return obj
|
||||
|
||||
|
||||
def get_state_dict_after_quant(model, prefix=''):
|
||||
for m in model.modules():
|
||||
if hasattr(m, 'weight') and hasattr(m.weight, 'bnb_quantized'):
|
||||
if not m.weight.bnb_quantized:
|
||||
original_device = m.weight.device
|
||||
m.cuda()
|
||||
m.to(original_device)
|
||||
|
||||
sd = model.state_dict()
|
||||
sd = {(prefix + k): v.clone() for k, v in sd.items()}
|
||||
return sd
|
||||
|
||||
Reference in New Issue
Block a user