diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 17bb4863..e2e62b94 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1584,6 +1584,8 @@ class SDTrainer(BaseSDTrainProcess):
                 self.scaler.update()
                 self.optimizer.zero_grad(set_to_none=True)
 
+            if self.adapter and isinstance(self.adapter, CustomAdapter):
+                self.adapter.post_weight_update()
             if self.ema is not None:
                 with self.timer('ema_update'):
                     self.ema.update()
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 7a2ee487..75fac738 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -402,7 +402,7 @@ class CustomAdapter(torch.nn.Module):
         if 'sv_adapter' in state_dict:
             self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
 
-        if 'vision_encoder' in state_dict and self.config.train_image_encoder:
+        if 'vision_encoder' in state_dict:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
 
         if 'fuse_module' in state_dict:
@@ -881,10 +881,14 @@ class CustomAdapter(torch.nn.Module):
             for attn_processor in self.te_adapter.adapter_modules:
                 yield from attn_processor.parameters(recurse)
         elif self.config.type == 'vision_direct':
-            for attn_processor in self.vd_adapter.adapter_modules:
-                yield from attn_processor.parameters(recurse)
-            if self.config.train_image_encoder:
-                yield from self.vision_encoder.parameters(recurse)
+            if self.config.train_scaler:
+                # only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
+                yield self.vd_adapter.block_scaler
+            else:
+                for attn_processor in self.vd_adapter.adapter_modules:
+                    yield from attn_processor.parameters(recurse)
+                if self.config.train_image_encoder:
+                    yield from self.vision_encoder.parameters(recurse)
         elif self.config.type == 'te_augmenter':
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
@@ -908,4 +912,10 @@ class CustomAdapter(torch.nn.Module):
                 additional[k] = v
         additional['clip_layer'] = self.config.clip_layer
         additional['image_encoder_arch'] = self.config.head_dim
-        return additional
\ No newline at end of file
+        return additional
+
+    def post_weight_update(self):
+        # do any kind of updates after the weight update
+        if self.config.type == 'vision_direct':
+            self.vd_adapter.post_weight_update()
+        pass
\ No newline at end of file
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index c180c18b..8755a040 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -5,9 +5,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 import weakref
 from typing import Union, TYPE_CHECKING, Optional
+from collections import OrderedDict
 
 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
+
+from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
 
 sys.path.append(REPOS_ROOT)
@@ -286,7 +289,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
     """Attention processor used typically in processing the SD3-like self-attention projections."""

     def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, adapter=None,
-                 adapter_hidden_size=None, has_bias=False, **kwargs):
+                 adapter_hidden_size=None, has_bias=False, block_idx=0, **kwargs):
         super().__init__()
 
         if not hasattr(F, "scaled_dot_product_attention"):
@@ -298,6 +301,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         self.adapter_hidden_size = adapter_hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.scale = scale
+        self.block_idx = block_idx
 
         self.to_k_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
         self.to_v_adapter = nn.Linear(adapter_hidden_size, hidden_size, bias=has_bias)
@@ -382,6 +386,10 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         # begin ip adapter
         if self.is_active and self.conditional_embeds is not None:
             adapter_hidden_states = self.conditional_embeds
+            block_scaler = self.adapter_ref().block_scaler
+            if block_scaler is not None:
+                block_scaler = block_scaler[self.block_idx]
+
             if adapter_hidden_states.shape[0] < batch_size:
                 adapter_hidden_states = torch.cat([
                     self.unconditional_embeds,
@@ -407,6 +415,15 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
             vd_hidden_states = vd_hidden_states.to(query.dtype)
 
+            # scale to block scaler
+            if block_scaler is not None:
+                orig_dtype = vd_hidden_states.dtype
+                if block_scaler.dtype != vd_hidden_states.dtype:
+                    vd_hidden_states = vd_hidden_states.to(block_scaler.dtype)
+                vd_hidden_states = vd_hidden_states * block_scaler
+                if block_scaler.dtype != orig_dtype:
+                    vd_hidden_states = vd_hidden_states.to(orig_dtype)
+
             hidden_states = hidden_states + self.scale * vd_hidden_states
 
             if encoder_hidden_states is not None:
@@ -437,6 +454,7 @@ class VisionDirectAdapter(torch.nn.Module):
         is_flux = sd.is_flux
         self.adapter_ref: weakref.ref = weakref.ref(adapter)
         self.sd_ref: weakref.ref = weakref.ref(sd)
+        self.config: AdapterConfig = adapter.config
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
 
         if adapter.config.clip_layer == "image_embeds":
@@ -469,6 +487,8 @@ class VisionDirectAdapter(torch.nn.Module):
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())
 
+        current_idx = 0
+
         for name in attn_processor_keys:
             if is_flux:
                 cross_attention_dim = None
@@ -550,6 +570,7 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter=self,
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
+                    block_idx=current_idx
                 )
             else:
                 attn_procs[name] = VisionDirectAdapterAttnProcessor(
@@ -560,7 +581,9 @@ class VisionDirectAdapter(torch.nn.Module):
                     adapter_hidden_size=self.token_size,
                     has_bias=False,
                 )
+            current_idx += 1
             attn_procs[name].load_state_dict(weights)
+
         if self.sd_ref().is_pixart:
             # we have to set them ourselves
             transformer: Transformer2DModel = sd.unet
@@ -595,14 +618,25 @@ class VisionDirectAdapter(torch.nn.Module):
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
 
-        # # add the mlp layer
-        # self.mlp = MLP(
-        #     in_dim=self.token_size,
-        #     out_dim=self.token_size,
-        #     hidden_dim=self.token_size,
-        #     # dropout=0.1,
-        #     use_residual=True
-        # )
+        num_modules = len(self.adapter_modules)
+        if self.config.train_scaler:
+            self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules).to(
+                dtype=torch.float32,
+                device=self.sd_ref().device_torch
+            ))
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+            self.block_scaler.requires_grad = True
+        else:
+            self.block_scaler = None
+
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        if self.config.train_scaler:
+            # only return the block scaler
+            if destination is None:
+                destination = OrderedDict()
+            destination[prefix + 'block_scaler'] = self.block_scaler
+            return destination
+        return super().state_dict(destination, prefix, keep_vars)
 
     # make a getter to see if is active
     @property
@@ -610,5 +644,21 @@ class VisionDirectAdapter(torch.nn.Module):
         return self.adapter_ref().is_active
 
     def forward(self, input):
-        # return self.mlp(input)
+        # block scaler keeps moving dtypes. make sure it is float32 here
+        # todo remove this when we have a real solution
+        if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
+            self.block_scaler.data = self.block_scaler.data.to(torch.float32)
         return input
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        if self.block_scaler is not None:
+            if self.block_scaler.dtype != torch.float32:
+                self.block_scaler.data = self.block_scaler.data.to(torch.float32)
+        return self
+
+    def post_weight_update(self):
+        # force block scaler to be mean of 1
+        if self.block_scaler is not None:
+            self.block_scaler.data = self.block_scaler.data / self.block_scaler.data.mean()
+        pass
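
For reviewers: the sketch below illustrates, in isolation, what the new `block_scaler` plus `post_weight_update()` pair does — one learnable scalar per attention block, renormalized to mean 1 after each optimizer step so the parameter encodes relative block weighting rather than a global gain. The toy `num_modules` value, the SGD optimizer, and the dummy loss are assumptions for the demo only, not code from this patch.

```python
import torch

# Toy stand-in for VisionDirectAdapter.block_scaler: one scalar per attention block.
# num_modules, the optimizer, and the dummy objective are illustrative assumptions.
num_modules = 4
block_scaler = torch.nn.Parameter(torch.ones(num_modules, dtype=torch.float32))
optimizer = torch.optim.SGD([block_scaler], lr=0.1)

for _ in range(3):
    # dummy objective that pushes the per-block scalers apart
    loss = (block_scaler * torch.arange(1.0, num_modules + 1)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # equivalent of post_weight_update(): pin the mean at 1 so only the
    # relative weighting between blocks is learned
    block_scaler.data = block_scaler.data / block_scaler.data.mean()

print(block_scaler.data, block_scaler.data.mean())  # mean stays at ~1.0
```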