added ema

This commit is contained in:
Jaret Burkett
2024-06-28 10:03:26 -06:00
parent 657fd09f25
commit 603ceca3ca
4 changed files with 367 additions and 3 deletions

View File

@@ -29,6 +29,7 @@ import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math
@@ -1510,6 +1511,9 @@ class SDTrainer(BaseSDTrainProcess):
# self.scaler.update()
# self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
if self.ema is not None:
with self.timer('ema_update'):
self.ema.update()
else:
# gradient accumulation. Just a place for breakpoint
pass

View File

@@ -22,6 +22,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.ema import ExponentialMovingAverage
from toolkit.embedding import Embedding
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
@@ -174,6 +175,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
self.snr_gos: Union[LearnableSNRGamma, None] = None
self.ema: ExponentialMovingAverage = None
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -253,9 +255,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
# post process
gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
# if we have an ema, set it to validation mode
if self.ema is not None:
self.ema.eval()
# send to be generated
self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
if self.ema is not None:
self.ema.train()
def update_training_metadata(self):
o_dict = OrderedDict({
"training_info": self.get_training_info()
@@ -369,6 +378,10 @@ class BaseSDTrainProcess(BaseTrainProcess):
def save(self, step=None):
flush()
if self.ema is not None:
# always save params as ema
self.ema.eval()
if not os.path.exists(self.save_root):
os.makedirs(self.save_root, exist_ok=True)
@@ -527,6 +540,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.print(f"Saved to {file_path}")
self.clean_up_saves()
self.post_save_hook(file_path)
if self.ema is not None:
self.ema.train()
flush()
# Called before the model is loaded
@@ -541,6 +557,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
def hook_before_train_loop(self):
pass
def setup_ema(self):
if self.train_config.ema_config.use_ema:
# our params are in groups. We need them as a single iterable
params = []
for group in self.optimizer.param_groups:
for param in group['params']:
params.append(param)
self.ema = ExponentialMovingAverage(
params,
self.train_config.ema_config.ema_decay
)
def before_dataset_load(self):
pass

View File

@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional, Literal, Union, TYPE_CHECKING
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
import random
import torch
@@ -133,6 +133,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'cont
CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip, clip, control_net
@@ -248,7 +249,7 @@ class TrainConfig:
self.start_step = kwargs.get('start_step', None)
self.free_u = kwargs.get('free_u', False)
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
@@ -331,6 +332,13 @@ class TrainConfig:
# applies negative loss on the prior to encourage network to diverge from it
self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
if ema_config is not None:
ema_config['use_ema'] = True
else:
ema_config = {'use_ema': False}
self.ema_config: EMAConfig = EMAConfig(**ema_config)
class ModelConfig:
@@ -376,6 +384,12 @@ class ModelConfig:
self.unet_sample_size = kwargs.get("unet_sample_size", None)
class EMAConfig:
def __init__(self, **kwargs):
self.use_ema: bool = kwargs.get('use_ema', False)
self.ema_decay: float = kwargs.get('ema_decay', 0.999)
class ReferenceDatasetConfig:
def __init__(self, **kwargs):
# can pass with a side by side pait or a folder with pos and neg folder
@@ -483,7 +497,7 @@ class DatasetConfig:
self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped
self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped
self.flip_x: bool = kwargs.get('flip_x', False)
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])

318
toolkit/ema.py Normal file
View File

@@ -0,0 +1,318 @@
from __future__ import division
from __future__ import unicode_literals
from typing import Iterable, Optional
import weakref
import copy
import contextlib
import torch
# Partially based on:
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
"""
Maintains (exponential) moving average of a set of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter` (typically from
`model.parameters()`).
Note that EMA is computed on *all* provided parameters,
regardless of whether or not they have `requires_grad = True`;
this allows a single EMA object to be consistantly used even
if which parameters are trainable changes step to step.
If you want to some parameters in the EMA, do not pass them
to the object in the first place. For example:
ExponentialMovingAverage(
parameters=[p for p in model.parameters() if p.requires_grad],
decay=0.9
)
will ignore parameters that do not require grad.
decay: The exponential decay.
use_num_updates: Whether to use number of updates when computing
averages.
"""
def __init__(
self,
parameters: Iterable[torch.nn.Parameter] = None,
decay: float = 0.995,
use_num_updates: bool = True
):
if parameters is None:
raise ValueError("parameters must be provided")
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.decay = decay
self.num_updates = 0 if use_num_updates else None
parameters = list(parameters)
self.shadow_params = [
p.clone().detach()
for p in parameters
]
self.collected_params = None
self._is_train_mode = True
# By maintaining only a weakref to each parameter,
# we maintain the old GC behaviour of ExponentialMovingAverage:
# if the model goes out of scope but the ExponentialMovingAverage
# is kept, no references to the model or its parameters will be
# maintained, and the model will be cleaned up.
self._params_refs = [weakref.ref(p) for p in parameters]
def _get_parameters(
self,
parameters: Optional[Iterable[torch.nn.Parameter]]
) -> Iterable[torch.nn.Parameter]:
if parameters is None:
parameters = [p() for p in self._params_refs]
if any(p is None for p in parameters):
raise ValueError(
"(One of) the parameters with which this "
"ExponentialMovingAverage "
"was initialized no longer exists (was garbage collected);"
" please either provide `parameters` explicitly or keep "
"the model to which they belong from being garbage "
"collected."
)
return parameters
else:
parameters = list(parameters)
if len(parameters) != len(self.shadow_params):
raise ValueError(
"Number of parameters passed as argument is different "
"from number of shadow parameters maintained by this "
"ExponentialMovingAverage"
)
return parameters
def update(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Update currently maintained parameters.
Call this every time the parameters are updated, such as the result of
the `optimizer.step()` call.
Args:
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
parameters used to initialize this object. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
decay = self.decay
if self.num_updates is not None:
self.num_updates += 1
decay = min(
decay,
(1 + self.num_updates) / (10 + self.num_updates)
)
one_minus_decay = 1.0 - decay
with torch.no_grad():
for s_param, param in zip(self.shadow_params, parameters):
tmp = (s_param - param)
# tmp will be a new tensor so we can do in-place
tmp.mul_(one_minus_decay)
s_param.sub_(tmp)
def copy_to(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Copy current averaged parameters into given collection of parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored moving averages. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
for s_param, param in zip(self.shadow_params, parameters):
param.data.copy_(s_param.data)
def store(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored. If `None`, the parameters of with which this
`ExponentialMovingAverage` was initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.collected_params = [
param.clone()
for param in parameters
]
def restore(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
) -> None:
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
if self.collected_params is None:
raise RuntimeError(
"This ExponentialMovingAverage has no `store()`ed weights "
"to `restore()`"
)
parameters = self._get_parameters(parameters)
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
@contextlib.contextmanager
def average_parameters(
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None
):
r"""
Context manager for validation/inference with averaged parameters.
Equivalent to:
ema.store()
ema.copy_to()
try:
...
finally:
ema.restore()
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters. If `None`, the
parameters with which this `ExponentialMovingAverage` was
initialized will be used.
"""
parameters = self._get_parameters(parameters)
self.store(parameters)
self.copy_to(parameters)
try:
yield
finally:
self.restore(parameters)
def to(self, device=None, dtype=None) -> None:
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
Args:
device: like `device` argument to `torch.Tensor.to`
"""
# .to() on the tensors handles None correctly
self.shadow_params = [
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
for p in self.shadow_params
]
if self.collected_params is not None:
self.collected_params = [
p.to(device=device, dtype=dtype)
if p.is_floating_point()
else p.to(device=device)
for p in self.collected_params
]
return
def state_dict(self) -> dict:
r"""Returns the state of the ExponentialMovingAverage as a dict."""
# Following PyTorch conventions, references to tensors are returned:
# "returns a reference to the state and not its copy!" -
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
return {
"decay": self.decay,
"num_updates": self.num_updates,
"shadow_params": self.shadow_params,
"collected_params": self.collected_params
}
def load_state_dict(self, state_dict: dict) -> None:
r"""Loads the ExponentialMovingAverage state.
Args:
state_dict (dict): EMA state. Should be an object returned
from a call to :meth:`state_dict`.
"""
# deepcopy, to be consistent with module API
state_dict = copy.deepcopy(state_dict)
self.decay = state_dict["decay"]
if self.decay < 0.0 or self.decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.num_updates = state_dict["num_updates"]
assert self.num_updates is None or isinstance(self.num_updates, int), \
"Invalid num_updates"
self.shadow_params = state_dict["shadow_params"]
assert isinstance(self.shadow_params, list), \
"shadow_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.shadow_params
), "shadow_params must all be Tensors"
self.collected_params = state_dict["collected_params"]
if self.collected_params is not None:
assert isinstance(self.collected_params, list), \
"collected_params must be a list"
assert all(
isinstance(p, torch.Tensor) for p in self.collected_params
), "collected_params must all be Tensors"
assert len(self.collected_params) == len(self.shadow_params), \
"collected_params and shadow_params had different lengths"
if len(self.shadow_params) == len(self._params_refs):
# Consistant with torch.optim.Optimizer, cast things to consistant
# device and dtype with the parameters
params = [p() for p in self._params_refs]
# If parameters have been garbage collected, just load the state
# we were given without change.
if not any(p is None for p in params):
# ^ parameter references are still good
for i, p in enumerate(params):
self.shadow_params[i] = self.shadow_params[i].to(
device=p.device, dtype=p.dtype
)
if self.collected_params is not None:
self.collected_params[i] = self.collected_params[i].to(
device=p.device, dtype=p.dtype
)
else:
raise ValueError(
"Tried to `load_state_dict()` with the wrong number of "
"parameters in the saved state."
)
def eval(self):
if self._is_train_mode:
with torch.no_grad():
self.store()
self.copy_to()
self._is_train_mode = False
def train(self):
if not self._is_train_mode:
with torch.no_grad():
self.restore()
self._is_train_mode = True