mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

added ema
@@ -29,6 +29,7 @@ import gc
import torch
from jobs.process import BaseSDTrainProcess
from torchvision import transforms
from diffusers import EMAModel
import math


@@ -1510,6 +1511,9 @@ class SDTrainer(BaseSDTrainProcess):
                    # self.scaler.update()
                    # self.optimizer.step()
                    self.optimizer.zero_grad(set_to_none=True)
                    if self.ema is not None:
                        with self.timer('ema_update'):
                            self.ema.update()
                else:
                    # gradient accumulation. Just a place for breakpoint
                    pass
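
The `ema.update()` call added above maintains a shadow copy of every parameter the optimizer touches. A standalone sketch of the arithmetic it performs (per the new toolkit/ema.py further down; not part of the commit), with plain tensors standing in for model parameters:

    import torch

    decay = 0.999
    shadow = torch.tensor([1.0, 2.0])  # EMA ("shadow") copy of a parameter
    param = torch.tensor([0.0, 0.0])   # live parameter after optimizer.step()

    # warm-up: the effective decay is capped early in training
    for num_updates in (1, 10, 100, 1000):
        effective_decay = min(decay, (1 + num_updates) / (10 + num_updates))
        print(num_updates, effective_decay)  # ~0.18, 0.55, 0.92, 0.99; reaches `decay` after ~9k updates

    # one update: shadow <- shadow - (1 - d) * (shadow - param),
    # i.e. the usual shadow <- d * shadow + (1 - d) * param
    shadow -= (1.0 - decay) * (shadow - param)
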
@@ -22,6 +22,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.ema import ExponentialMovingAverage
from toolkit.embedding import Embedding
from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter

@@ -174,6 +175,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.embed_config is not None or is_training_adapter:
            self.named_lora = True
        self.snr_gos: Union[LearnableSNRGamma, None] = None
        self.ema: ExponentialMovingAverage = None

    def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
        # override in subclass

@@ -253,9 +255,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # post process
        gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)

        # if we have an ema, set it to validation mode
        if self.ema is not None:
            self.ema.eval()

        # send to be generated
        self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)

        if self.ema is not None:
            self.ema.train()

    def update_training_metadata(self):
        o_dict = OrderedDict({
            "training_info": self.get_training_info()

@@ -369,6 +378,10 @@ class BaseSDTrainProcess(BaseTrainProcess):

    def save(self, step=None):
        flush()
        if self.ema is not None:
            # always save params as ema
            self.ema.eval()

        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)

@@ -527,6 +540,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.print(f"Saved to {file_path}")
        self.clean_up_saves()
        self.post_save_hook(file_path)

        if self.ema is not None:
            self.ema.train()
        flush()

    # Called before the model is loaded

@@ -541,6 +557,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
    def hook_before_train_loop(self):
        pass

    def setup_ema(self):
        if self.train_config.ema_config.use_ema:
            # our params are in groups. We need them as a single iterable
            params = []
            for group in self.optimizer.param_groups:
                for param in group['params']:
                    params.append(param)
            self.ema = ExponentialMovingAverage(
                params,
                self.train_config.ema_config.ema_decay
            )

    def before_dataset_load(self):
        pass
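
Taken together, the hooks above give the EMA a simple lifecycle: build it from the optimizer's parameters (setup_ema), fold the weights in after every optimizer step (ema.update), and temporarily swap the averaged weights in around sampling and saving (ema.eval / ema.train). A minimal sketch of that flow, assuming a toy model and placeholder loss rather than the trainer's real hook_train_loop/sample/save plumbing:

    import torch
    from toolkit.ema import ExponentialMovingAverage

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # mirrors setup_ema(): flatten optimizer param groups into one iterable
    params = [p for group in optimizer.param_groups for p in group['params']]
    ema = ExponentialMovingAverage(params, 0.999)

    for step in range(1000):
        loss = model(torch.randn(4, 8)).pow(2).mean()  # placeholder loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        ema.update()  # shadow params track the just-updated live params

        if step % 250 == 249:
            ema.eval()   # swap EMA weights in, as sample()/save() do above
            # ... generate sample images / write a checkpoint here ...
            ema.train()  # restore the live (non-EMA) weights and keep training
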
@@ -1,6 +1,6 @@
import os
import time
from typing import List, Optional, Literal, Union, TYPE_CHECKING
from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
import random

import torch

@@ -133,6 +133,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'cont

CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']


class AdapterConfig:
    def __init__(self, **kwargs):
        self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip, clip, control_net

@@ -248,7 +249,7 @@ class TrainConfig:
        self.start_step = kwargs.get('start_step', None)
        self.free_u = kwargs.get('free_u', False)
        self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
        self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
        self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
        self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
        self.img_multiplier = kwargs.get('img_multiplier', 1.0)
        self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)

@@ -331,6 +332,13 @@ class TrainConfig:
        # applies negative loss on the prior to encourage network to diverge from it
        self.do_prior_divergence = kwargs.get('do_prior_divergence', False)

        ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
        if ema_config is not None:
            ema_config['use_ema'] = True
        else:
            ema_config = {'use_ema': False}

        self.ema_config: EMAConfig = EMAConfig(**ema_config)


class ModelConfig:

@@ -376,6 +384,12 @@ class ModelConfig:
        self.unet_sample_size = kwargs.get("unet_sample_size", None)


class EMAConfig:
    def __init__(self, **kwargs):
        self.use_ema: bool = kwargs.get('use_ema', False)
        self.ema_decay: float = kwargs.get('ema_decay', 0.999)


class ReferenceDatasetConfig:
    def __init__(self, **kwargs):
        # can pass with a side by side pair or a folder with pos and neg folder

@@ -483,7 +497,7 @@ class DatasetConfig:
        self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
        self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
        self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
        self.keep_tokens: int = kwargs.get('keep_tokens', 0)  # number of first tokens to always keep unless caption dropped
        self.keep_tokens: int = kwargs.get('keep_tokens', 0)  # number of first tokens to always keep unless caption dropped
        self.flip_x: bool = kwargs.get('flip_x', False)
        self.flip_y: bool = kwargs.get('flip_y', False)
        self.augments: List[str] = kwargs.get('augments', [])
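
Since TrainConfig only checks whether an ema_config mapping was supplied, enabling EMA amounts to passing a dict with the desired decay; its mere presence sets use_ema. A hedged example, assuming these config classes live in toolkit.config_modules (the file is not named in this diff):

    # Illustrative only; other TrainConfig keys are omitted.
    from toolkit.config_modules import TrainConfig

    with_ema = TrainConfig(ema_config={'ema_decay': 0.999})
    assert with_ema.ema_config.use_ema is True
    assert with_ema.ema_config.ema_decay == 0.999

    without_ema = TrainConfig()  # no ema_config -> EMAConfig(use_ema=False)
    assert without_ema.ema_config.use_ema is False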

318  toolkit/ema.py  Normal file
@@ -0,0 +1,318 @@
from __future__ import division
from __future__ import unicode_literals

from typing import Iterable, Optional
import weakref
import copy
import contextlib

import torch


# Partially based on:
# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    Args:
        parameters: Iterable of `torch.nn.Parameter` (typically from
            `model.parameters()`).
            Note that EMA is computed on *all* provided parameters,
            regardless of whether or not they have `requires_grad = True`;
            this allows a single EMA object to be consistently used even
            if which parameters are trainable changes step to step.

            If you want to exclude some parameters from the EMA, do not
            pass them to the object in the first place. For example:

                ExponentialMovingAverage(
                    parameters=[p for p in model.parameters() if p.requires_grad],
                    decay=0.9
                )

            will ignore parameters that do not require grad.

        decay: The exponential decay.

        use_num_updates: Whether to use number of updates when computing
            averages.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter] = None,
        decay: float = 0.995,
        use_num_updates: bool = True
    ):
        if parameters is None:
            raise ValueError("parameters must be provided")
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        parameters = list(parameters)
        self.shadow_params = [
            p.clone().detach()
            for p in parameters
        ]
        self.collected_params = None
        self._is_train_mode = True
        # By maintaining only a weakref to each parameter,
        # we maintain the old GC behaviour of ExponentialMovingAverage:
        # if the model goes out of scope but the ExponentialMovingAverage
        # is kept, no references to the model or its parameters will be
        # maintained, and the model will be cleaned up.
        self._params_refs = [weakref.ref(p) for p in parameters]

    def _get_parameters(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]]
    ) -> Iterable[torch.nn.Parameter]:
        if parameters is None:
            parameters = [p() for p in self._params_refs]
            if any(p is None for p in parameters):
                raise ValueError(
                    "(One of) the parameters with which this "
                    "ExponentialMovingAverage "
                    "was initialized no longer exists (was garbage collected);"
                    " please either provide `parameters` explicitly or keep "
                    "the model to which they belong from being garbage "
                    "collected."
                )
            return parameters
        else:
            parameters = list(parameters)
            if len(parameters) != len(self.shadow_params):
                raise ValueError(
                    "Number of parameters passed as argument is different "
                    "from number of shadow parameters maintained by this "
                    "ExponentialMovingAverage"
                )
            return parameters

    def update(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
                parameters used to initialize this object. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            decay = min(
                decay,
                (1 + self.num_updates) / (10 + self.num_updates)
            )
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, parameters):
                tmp = (s_param - param)
                # tmp will be a new tensor so we can do in-place
                tmp.mul_(one_minus_decay)
                s_param.sub_(tmp)

    def copy_to(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        self.collected_params = [
            param.clone()
            for param in parameters
        ]

    def restore(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ) -> None:
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        if self.collected_params is None:
            raise RuntimeError(
                "This ExponentialMovingAverage has no `store()`ed weights "
                "to `restore()`"
            )
        parameters = self._get_parameters(parameters)
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    @contextlib.contextmanager
    def average_parameters(
        self,
        parameters: Optional[Iterable[torch.nn.Parameter]] = None
    ):
        r"""
        Context manager for validation/inference with averaged parameters.

        Equivalent to:

            ema.store()
            ema.copy_to()
            try:
                ...
            finally:
                ema.restore()

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the
                parameters with which this `ExponentialMovingAverage` was
                initialized will be used.
        """
        parameters = self._get_parameters(parameters)
        self.store(parameters)
        self.copy_to(parameters)
        try:
            yield
        finally:
            self.restore(parameters)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype)
            if p.is_floating_point()
            else p.to(device=device)
            for p in self.shadow_params
        ]
        if self.collected_params is not None:
            self.collected_params = [
                p.to(device=device, dtype=dtype)
                if p.is_floating_point()
                else p.to(device=device)
                for p in self.collected_params
            ]
        return

    def state_dict(self) -> dict:
        r"""Returns the state of the ExponentialMovingAverage as a dict."""
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "shadow_params": self.shadow_params,
            "collected_params": self.collected_params
        }

    def load_state_dict(self, state_dict: dict) -> None:
        r"""Loads the ExponentialMovingAverage state.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)
        self.decay = state_dict["decay"]
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.num_updates = state_dict["num_updates"]
        assert self.num_updates is None or isinstance(self.num_updates, int), \
            "Invalid num_updates"

        self.shadow_params = state_dict["shadow_params"]
        assert isinstance(self.shadow_params, list), \
            "shadow_params must be a list"
        assert all(
            isinstance(p, torch.Tensor) for p in self.shadow_params
        ), "shadow_params must all be Tensors"

        self.collected_params = state_dict["collected_params"]
        if self.collected_params is not None:
            assert isinstance(self.collected_params, list), \
                "collected_params must be a list"
            assert all(
                isinstance(p, torch.Tensor) for p in self.collected_params
            ), "collected_params must all be Tensors"
            assert len(self.collected_params) == len(self.shadow_params), \
                "collected_params and shadow_params had different lengths"

        if len(self.shadow_params) == len(self._params_refs):
            # Consistent with torch.optim.Optimizer, cast things to a consistent
            # device and dtype with the parameters
            params = [p() for p in self._params_refs]
            # If parameters have been garbage collected, just load the state
            # we were given without change.
            if not any(p is None for p in params):
                # ^ parameter references are still good
                for i, p in enumerate(params):
                    self.shadow_params[i] = self.shadow_params[i].to(
                        device=p.device, dtype=p.dtype
                    )
                    if self.collected_params is not None:
                        self.collected_params[i] = self.collected_params[i].to(
                            device=p.device, dtype=p.dtype
                        )
        else:
            raise ValueError(
                "Tried to `load_state_dict()` with the wrong number of "
                "parameters in the saved state."
            )

    def eval(self):
        if self._is_train_mode:
            with torch.no_grad():
                self.store()
                self.copy_to()
            self._is_train_mode = False

    def train(self):
        if not self._is_train_mode:
            with torch.no_grad():
                self.restore()
            self._is_train_mode = True
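
End to end, the class above can also be used on its own; a hedged usage sketch (the toy model and data are placeholders, only the ExponentialMovingAverage calls come from the file above):

    import torch
    from toolkit.ema import ExponentialMovingAverage

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    for _ in range(100):
        loss = (model(torch.randn(16, 4)) - 1.0).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update()  # fold the new weights into the shadow copies

    # evaluate with averaged weights, then automatically restore the live ones
    with ema.average_parameters():
        with torch.no_grad():
            val_loss = (model(torch.randn(16, 4)) - 1.0).pow(2).mean()

    # EMA state can be checkpointed alongside the model and optimizer state
    ckpt = {'model': model.state_dict(), 'ema': ema.state_dict()}
    restored = ExponentialMovingAverage(model.parameters(), decay=0.995)
    restored.load_state_dict(ckpt['ema'])

The eval()/train() pair used by the trainer is a stateful wrapper around the same store()/copy_to()/restore() primitives that average_parameters() uses.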