# 1st edit by https://github.com/CompVis/latent-diffusion
# 2nd edit by https://github.com/Stability-AI/stablediffusion
# 3rd edit by https://github.com/Stability-AI/generative-models
# 4th edit by https://github.com/comfyanonymous/ComfyUI


# This file is not used in the image diffusion backend.
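
# LitEma maintains an exponential moving average (EMA) of a model's trainable
# parameters: one shadow buffer is registered per parameter and nudged toward
# the live weights on every update, giving smoother weights for evaluation and
# checkpointing.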

import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
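        """
        Track an exponential moving average of ``model``'s trainable parameters.

        Args:
            model: module whose parameters are shadowed; one buffer is
                registered per trainable parameter.
            decay: EMA decay rate, must lie in [0, 1].
            use_num_upates: if True, warm up the effective decay with
                min(decay, (1 + num_updates) / (10 + num_updates)).
        """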
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
                             else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
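        """
        Update the shadow buffers in place from ``model``'s current parameters.

        Each shadow value follows shadow -= (1 - decay) * (shadow - param),
        i.e. shadow = decay * shadow + (1 - decay) * param; with warm-up
        enabled, the effective decay is reduced for the first few updates.
        """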
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
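        """
        Copy the shadow (EMA) values into ``model``'s trainable parameters.

        Call `store` first if the current weights need to be restored later.
        """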
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
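

# Typical usage (an illustrative sketch, not part of the original module): after
# each optimizer step, call the EMA object on the model to refresh the shadow
# weights, then temporarily swap them in for evaluation or checkpointing.
#
#     ema = LitEma(model)
#     ...                              # training step, optimizer.step()
#     ema(model)                       # update shadow parameters
#
#     ema.store(model.parameters())    # stash the live weights
#     ema.copy_to(model)               # evaluate / save with EMA weights
#     ema.restore(model.parameters())  # put the live weights back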