RAM optimization round 2

This commit is contained in:
AUTOMATIC1111
2023-08-16 09:55:35 +03:00
parent 85fcb7b8df
commit 86221269f9
2 changed files with 48 additions and 8 deletions

View File

@@ -304,7 +304,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)
weights_backup = getattr(self, "network_weights_backup", None)
if weights_backup is None:
if weights_backup is None and wanted_names != ():
if current_names != ():
raise RuntimeError("no backup weights found and current weights are not unchanged")
if isinstance(self, torch.nn.MultiheadAttention):
weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
else: