diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 3a0bca57..33f162e7 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -29,6 +29,7 @@ import gc
 import torch
 from jobs.process import BaseSDTrainProcess
 from torchvision import transforms
+from diffusers import EMAModel
 import math
@@ -1510,6 +1511,9 @@ class SDTrainer(BaseSDTrainProcess):
                 # self.scaler.update()
                 # self.optimizer.step()
                 self.optimizer.zero_grad(set_to_none=True)
+                if self.ema is not None:
+                    with self.timer('ema_update'):
+                        self.ema.update()
             else:
                 # gradient accumulation. Just a place for breakpoint
                 pass
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index c93730e5..20d7d322 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -22,6 +22,7 @@ from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
 from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
 from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+from toolkit.ema import ExponentialMovingAverage
 from toolkit.embedding import Embedding
 from toolkit.image_utils import show_tensors, show_latents
 from toolkit.ip_adapter import IPAdapter
@@ -174,6 +175,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         if self.embed_config is not None or is_training_adapter:
             self.named_lora = True
         self.snr_gos: Union[LearnableSNRGamma, None] = None
+        self.ema: ExponentialMovingAverage = None
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
         # override in subclass
@@ -253,9 +255,16 @@
         # post process
         gen_img_config_list = self.post_process_generate_image_config_list(gen_img_config_list)
 
+        # if we have an ema, set it to validation mode
+        if self.ema is not None:
+            self.ema.eval()
+
         # send to be generated
         self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
 
+        if self.ema is not None:
+            self.ema.train()
+
     def update_training_metadata(self):
         o_dict = OrderedDict({
             "training_info": self.get_training_info()
         })
@@ -369,6 +378,10 @@
     def save(self, step=None):
         flush()
+        if self.ema is not None:
+            # always save params as ema
+            self.ema.eval()
+
         if not os.path.exists(self.save_root):
             os.makedirs(self.save_root, exist_ok=True)
 
@@ -527,6 +540,9 @@
         self.print(f"Saved to {file_path}")
         self.clean_up_saves()
         self.post_save_hook(file_path)
+
+        if self.ema is not None:
+            self.ema.train()
         flush()
 
     # Called before the model is loaded
@@ -541,6 +557,18 @@
     def hook_before_train_loop(self):
         pass
 
+    def setup_ema(self):
+        if self.train_config.ema_config.use_ema:
+            # our params are in groups. We need them as a single iterable
+            params = []
+            for group in self.optimizer.param_groups:
+                for param in group['params']:
+                    params.append(param)
+            self.ema = ExponentialMovingAverage(
+                params,
+                self.train_config.ema_config.ema_decay
+            )
+
     def before_dataset_load(self):
         pass
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 2fd1795b..deb6cecf 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import List, Optional, Literal, Union, TYPE_CHECKING
+from typing import List, Optional, Literal, Union, TYPE_CHECKING, Dict
 import random
 
 import torch
@@ -133,6 +133,7 @@ AdapterTypes = Literal['t2i', 'ip', 'ip+', 'clip', 'ilora', 'photo_maker', 'cont
 CLIPLayer = Literal['penultimate_hidden_states', 'image_embeds', 'last_hidden_state']
 
+
 class AdapterConfig:
     def __init__(self, **kwargs):
         self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip, clip, control_net
@@ -248,7 +249,7 @@ class TrainConfig:
         self.start_step = kwargs.get('start_step', None)
         self.free_u = kwargs.get('free_u', False)
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
-        self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i') # t2i, control_net
+        self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
@@ -331,6 +332,13 @@ class TrainConfig:
         # applies negative loss on the prior to encourage network to diverge from it
         self.do_prior_divergence = kwargs.get('do_prior_divergence', False)
 
+        ema_config: Union[Dict, None] = kwargs.get('ema_config', None)
+        if ema_config is not None:
+            ema_config['use_ema'] = True
+        else:
+            ema_config = {'use_ema': False}
+
+        self.ema_config: EMAConfig = EMAConfig(**ema_config)
 
 
 class ModelConfig:
@@ -376,6 +384,12 @@ class ModelConfig:
         self.unet_sample_size = kwargs.get("unet_sample_size", None)
 
 
+class EMAConfig:
+    def __init__(self, **kwargs):
+        self.use_ema: bool = kwargs.get('use_ema', False)
+        self.ema_decay: float = kwargs.get('ema_decay', 0.999)
+
+
 class ReferenceDatasetConfig:
     def __init__(self, **kwargs):
         # can pass with a side by side pait or a folder with pos and neg folder
@@ -483,7 +497,7 @@ class DatasetConfig:
         self.token_dropout_rate: float = float(kwargs.get('token_dropout_rate', 0.0))
         self.shuffle_tokens: bool = kwargs.get('shuffle_tokens', False)
         self.caption_dropout_rate: float = float(kwargs.get('caption_dropout_rate', 0.0))
-        self.keep_tokens: int = kwargs.get('keep_tokens', 0) # #of first tokens to always keep unless caption dropped
+        self.keep_tokens: int = kwargs.get('keep_tokens', 0)  # #of first tokens to always keep unless caption dropped
         self.flip_x: bool = kwargs.get('flip_x', False)
         self.flip_y: bool = kwargs.get('flip_y', False)
         self.augments: List[str] = kwargs.get('augments', [])
diff --git a/toolkit/ema.py b/toolkit/ema.py
new file mode 100644
index 00000000..43eb8c8f
--- /dev/null
+++ b/toolkit/ema.py
@@ -0,0 +1,318 @@
+from __future__ import division
+from __future__ import unicode_literals
+
+from typing import Iterable, Optional
+import weakref
+import copy
+import contextlib
+
+import torch
+
+
+# Partially based on:
+# https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
+class ExponentialMovingAverage:
+    """
+    Maintains (exponential) moving average of a set of parameters.
+
+    Args:
+        parameters: Iterable of `torch.nn.Parameter` (typically from
+            `model.parameters()`).
+            Note that EMA is computed on *all* provided parameters,
+            regardless of whether or not they have `requires_grad = True`;
+            this allows a single EMA object to be consistently used even
+            if the set of trainable parameters changes from step to step.
+
+            If you want to exclude some parameters from the EMA, do not
+            pass them to the object in the first place. For example:
+
+                ExponentialMovingAverage(
+                    parameters=[p for p in model.parameters() if p.requires_grad],
+                    decay=0.9
+                )
+
+            will ignore parameters that do not require grad.
+
+        decay: The exponential decay.
+
+        use_num_updates: Whether to use number of updates when computing
+            averages.
+    """
+
+    def __init__(
+        self,
+        parameters: Iterable[torch.nn.Parameter] = None,
+        decay: float = 0.995,
+        use_num_updates: bool = True
+    ):
+        if parameters is None:
+            raise ValueError("parameters must be provided")
+        if decay < 0.0 or decay > 1.0:
+            raise ValueError('Decay must be between 0 and 1')
+        self.decay = decay
+        self.num_updates = 0 if use_num_updates else None
+        parameters = list(parameters)
+        self.shadow_params = [
+            p.clone().detach()
+            for p in parameters
+        ]
+        self.collected_params = None
+        self._is_train_mode = True
+        # By maintaining only a weakref to each parameter,
+        # we maintain the old GC behaviour of ExponentialMovingAverage:
+        # if the model goes out of scope but the ExponentialMovingAverage
+        # is kept, no references to the model or its parameters will be
+        # maintained, and the model will be cleaned up.
+        self._params_refs = [weakref.ref(p) for p in parameters]
+
+    def _get_parameters(
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]]
+    ) -> Iterable[torch.nn.Parameter]:
+        if parameters is None:
+            parameters = [p() for p in self._params_refs]
+            if any(p is None for p in parameters):
+                raise ValueError(
+                    "(One of) the parameters with which this "
+                    "ExponentialMovingAverage "
+                    "was initialized no longer exists (was garbage collected);"
+                    " please either provide `parameters` explicitly or keep "
+                    "the model to which they belong from being garbage "
+                    "collected."
+                )
+            return parameters
+        else:
+            parameters = list(parameters)
+            if len(parameters) != len(self.shadow_params):
+                raise ValueError(
+                    "Number of parameters passed as argument is different "
+                    "from number of shadow parameters maintained by this "
+                    "ExponentialMovingAverage"
+                )
+            return parameters
+
+    def update(
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ) -> None:
+        """
+        Update currently maintained parameters.
+
+        Call this every time the parameters are updated, such as after
+        the `optimizer.step()` call.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+                parameters used to initialize this object. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        decay = self.decay
+        if self.num_updates is not None:
+            self.num_updates += 1
+            decay = min(
+                decay,
+                (1 + self.num_updates) / (10 + self.num_updates)
+            )
+        one_minus_decay = 1.0 - decay
+        with torch.no_grad():
+            for s_param, param in zip(self.shadow_params, parameters):
+                tmp = (s_param - param)
+                # tmp will be a new tensor so we can do in-place
+                tmp.mul_(one_minus_decay)
+                s_param.sub_(tmp)
+
+    def copy_to(
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def store(
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ) -> None:
+        """
+        Save the current parameters for restoring later.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters with which this
+                `ExponentialMovingAverage` was initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        self.collected_params = [
+            param.clone()
+            for param in parameters
+        ]
+
+    def restore(
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ) -> None:
+        """
+        Restore the parameters stored with the `store` method.
+        Useful for validating the model with EMA parameters without affecting
+        the original optimization process. Store the parameters before calling
+        `copy_to`. After validation (or model saving), use this to restore the
+        former parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        if self.collected_params is None:
+            raise RuntimeError(
+                "This ExponentialMovingAverage has no `store()`ed weights "
+                "to `restore()`"
+            )
+        parameters = self._get_parameters(parameters)
+        for c_param, param in zip(self.collected_params, parameters):
+            param.data.copy_(c_param.data)
+
+    @contextlib.contextmanager
+    def average_parameters(
+        self,
+        parameters: Optional[Iterable[torch.nn.Parameter]] = None
+    ):
+        r"""
+        Context manager for validation/inference with averaged parameters.
+
+        Equivalent to:
+
+            ema.store()
+            ema.copy_to()
+            try:
+                ...
+            finally:
+                ema.restore()
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored parameters. If `None`, the
+                parameters with which this `ExponentialMovingAverage` was
+                initialized will be used.
+        """
+        parameters = self._get_parameters(parameters)
+        self.store(parameters)
+        self.copy_to(parameters)
+        try:
+            yield
+        finally:
+            self.restore(parameters)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype)
+            if p.is_floating_point()
+            else p.to(device=device)
+            for p in self.shadow_params
+        ]
+        if self.collected_params is not None:
+            self.collected_params = [
+                p.to(device=device, dtype=dtype)
+                if p.is_floating_point()
+                else p.to(device=device)
+                for p in self.collected_params
+            ]
+        return
+
+    def state_dict(self) -> dict:
+        r"""Returns the state of the ExponentialMovingAverage as a dict."""
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "num_updates": self.num_updates,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""Loads the ExponentialMovingAverage state.
+
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError('Decay must be between 0 and 1')
+        self.num_updates = state_dict["num_updates"]
+        assert self.num_updates is None or isinstance(self.num_updates, int), \
+            "Invalid num_updates"
+
+        self.shadow_params = state_dict["shadow_params"]
+        assert isinstance(self.shadow_params, list), \
+            "shadow_params must be a list"
+        assert all(
+            isinstance(p, torch.Tensor) for p in self.shadow_params
+        ), "shadow_params must all be Tensors"
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            assert isinstance(self.collected_params, list), \
+                "collected_params must be a list"
+            assert all(
+                isinstance(p, torch.Tensor) for p in self.collected_params
+            ), "collected_params must all be Tensors"
+            assert len(self.collected_params) == len(self.shadow_params), \
+                "collected_params and shadow_params had different lengths"
+
+        if len(self.shadow_params) == len(self._params_refs):
+            # Consistent with torch.optim.Optimizer, cast things to a
+            # consistent device and dtype with the parameters
+            params = [p() for p in self._params_refs]
+            # If parameters have been garbage collected, just load the state
+            # we were given without change.
+            if not any(p is None for p in params):
+                # ^ parameter references are still good
+                for i, p in enumerate(params):
+                    self.shadow_params[i] = self.shadow_params[i].to(
+                        device=p.device, dtype=p.dtype
+                    )
+                    if self.collected_params is not None:
+                        self.collected_params[i] = self.collected_params[i].to(
+                            device=p.device, dtype=p.dtype
+                        )
+        else:
+            raise ValueError(
+                "Tried to `load_state_dict()` with the wrong number of "
+                "parameters in the saved state."
+            )
+
+    def eval(self):
+        if self._is_train_mode:
+            with torch.no_grad():
+                self.store()
+                self.copy_to()
+            self._is_train_mode = False
+
+    def train(self):
+        if not self._is_train_mode:
+            with torch.no_grad():
+                self.restore()
+            self._is_train_mode = True
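Usage note: the sketch below is illustrative only and is not part of the patch; the model, optimizer, and training loop are stand-in assumptions. It mirrors how the trainer is expected to wire the new pieces together: flatten the optimizer's param groups the way setup_ema() does, call update() after each optimizer step, and bracket sampling or saving with eval()/train() so the averaged weights are swapped in and back out.

    # Illustrative sketch only -- model, optimizer, and loop are stand-ins.
    import torch
    from toolkit.ema import ExponentialMovingAverage

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # same flattening that setup_ema() performs: optimizer params live in groups
    params = [p for group in optimizer.param_groups for p in group['params']]
    ema = ExponentialMovingAverage(params, decay=0.999)  # ema_decay from the config

    for _ in range(100):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        ema.update()  # shadow params track the live params after each step

    ema.eval()    # store the live weights and copy the EMA weights in (sampling/saving)
    # ... generate samples or save a checkpoint here ...
    ema.train()   # restore the live weights and resume training

With this in place, a training config only needs an ema_config block (for example one that sets ema_decay) to turn the feature on; TrainConfig sets use_ema automatically whenever the block is present.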