From 67c2e44edbe8ad4520e2f4ca377e5e0fea5edf00 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 21 Nov 2024 20:01:52 -0700
Subject: [PATCH] Added support for training flux redux adapters

---
 jobs/process/BaseSDTrainProcess.py |  7 +++-
 toolkit/custom_adapter.py          | 58 ++++++++++++++++++++++++++++--
 toolkit/models/redux.py            | 26 ++++++++++++++
 3 files changed, 88 insertions(+), 3 deletions(-)
 create mode 100644 toolkit/models/redux.py

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 228f9491..f4c10cf3 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -513,12 +513,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                     # move it back
                     self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
                 else:
+                    direct_save = False
+                    if self.adapter_config.train_only_image_encoder:
+                        direct_save = True
+                    if self.adapter_config.type == 'redux':
+                        direct_save = True
                     save_ip_adapter_from_diffusers(
                         state_dict,
                         output_file=file_path,
                         meta=save_meta,
                         dtype=get_torch_dtype(self.save_config.dtype),
-                        direct_save=self.adapter_config.train_only_image_encoder
+                        direct_save=direct_save
                     )
             else:
                 if self.save_config.save_format == "diffusers":
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 28fdd78a..6ca66020 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -14,6 +14,7 @@ from toolkit.models.single_value_adapter import SingleValueAdapter
 from toolkit.models.te_adapter import TEAdapter
 from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
+from toolkit.models.redux import ReduxImageEncoder
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
@@ -93,6 +94,8 @@ class CustomAdapter(torch.nn.Module):
         self.te_augmenter: TEAugAdapter = None
         self.vd_adapter: VisionDirectAdapter = None
         self.single_value_adapter: SingleValueAdapter = None
+        self.redux_adapter: ReduxImageEncoder = None
+
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None

@@ -201,6 +204,8 @@
             self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
         elif self.adapter_type == 'single_value':
             self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
+        elif self.adapter_type == 'redux':
+            self.redux_adapter = ReduxImageEncoder(1152, 4096, self.device, torch_dtype)
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")

@@ -423,6 +428,13 @@
                 self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
             except Exception as e:
                 print(e)
+        if 'redux_up' in state_dict:
+            # state dict is separated, so recombine it
+            new_dict = {}
+            for k, v in state_dict.items():
+                for k2, v2 in v.items():
+                    new_dict[k + '.' + k2] = v2
+            self.redux_adapter.load_state_dict(new_dict, strict=True)

         pass

@@ -466,6 +478,11 @@
             state_dict["vision_encoder"] = self.vision_encoder.state_dict()
             state_dict["ilora"] = self.ilora_module.state_dict()
             return state_dict
+        elif self.adapter_type == 'redux':
+            d = self.redux_adapter.state_dict()
+            for k, v in d.items():
+                state_dict[k] = v
+            return state_dict
         else:
             raise NotImplementedError

@@ -482,7 +499,7 @@
         prompt: Union[List[str], str],
         is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux':
             return prompt
         elif self.adapter_type == 'text_encoder':
             # todo allow for training
@@ -604,7 +621,7 @@
         if self.adapter_type == 'ilora':
             return prompt_embeds

-        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'redux':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
                 return prompt_embeds.clone()
@@ -626,6 +643,7 @@
                     return_tensors="pt",
                     do_resize=True,
                     do_rescale=False,
+                    do_convert_rgb=True
                 ).pixel_values
             else:
                 clip_image = tensors_0_1
@@ -706,7 +724,41 @@
                 )
                 return prompt_embeds
+            elif self.adapter_type == 'redux':
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            output_hidden_states=True,
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, output_hidden_states=True
+                            )
+                    img_embeds = id_embeds['last_hidden_state']
+
+                    if self.config.quad_image:
+                        # get the outputs of the quad images
+                        chunks = img_embeds.chunk(quad_count, dim=0)
+                        chunk_sum = torch.zeros_like(chunks[0])
+                        for chunk in chunks:
+                            chunk_sum = chunk_sum + chunk
+
+                        # get the mean of them
+                        img_embeds = chunk_sum / quad_count
+
+                    if not is_training or not self.config.train_image_encoder:
+                        img_embeds = img_embeds.detach()
+
+                    img_embeds = self.redux_adapter(img_embeds.to(self.device, get_torch_dtype(self.sd_ref().dtype)))
+
+                    prompt_embeds.text_embeds = torch.cat((prompt_embeds.text_embeds, img_embeds), dim=-2)
+                    return prompt_embeds
         else:
             return prompt_embeds


@@ -945,6 +997,8 @@
             yield from self.vision_encoder.parameters(recurse)
         elif self.config.type == 'single_value':
             yield from self.single_value_adapter.parameters(recurse)
+        elif self.config.type == 'redux':
+            yield from self.redux_adapter.parameters(recurse)
         else:
             raise NotImplementedError
diff --git a/toolkit/models/redux.py b/toolkit/models/redux.py
new file mode 100644
index 00000000..609ac50a
--- /dev/null
+++ b/toolkit/models/redux.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+
+class ReduxImageEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+        self.redux_dim = redux_dim
+        self.device = device
+        self.dtype = dtype
+        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+        self.redux_down = nn.Linear(
+            txt_in_features * 3, txt_in_features, dtype=dtype)
+
+    def forward(self, sigclip_embeds) -> torch.Tensor:
+        x = self.redux_up(sigclip_embeds)
+        x = torch.nn.functional.silu(x)
+
+        projected_x = self.redux_down(x)
+        return projected_x
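
Usage sketch (not part of the patch): a quick shape check of the new
ReduxImageEncoder, assuming SigLIP-style image features (hidden size 1152,
here 729 patch tokens) and T5-style text features (hidden size 4096) to match
the defaults above; batch size and token counts are illustrative.

    import torch
    from toolkit.models.redux import ReduxImageEncoder

    encoder = ReduxImageEncoder(redux_dim=1152, txt_in_features=4096, dtype=torch.float32)
    img_tokens = torch.randn(1, 729, 1152)     # vision encoder last_hidden_state
    img_embeds = encoder(img_tokens)           # 1152 -> 3*4096 -> silu -> 4096
    text_embeds = torch.randn(1, 512, 4096)    # text encoder output
    # CustomAdapter appends the projected image tokens on the token axis (dim=-2)
    cond = torch.cat((text_embeds, img_embeds), dim=-2)
    print(img_embeds.shape, cond.shape)        # (1, 729, 4096), (1, 1241, 4096)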
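
Checkpoint layout note (not part of the patch): the 'redux_up' branch added to
load_state_dict above expects the saved weights grouped into one nested dict
per top-level module, and flattens them back into regular 'module.param' keys.
A minimal sketch of that recombination, with illustrative shapes following the
1152 -> 3*4096 -> 4096 projection:

    import torch

    # nested layout the load path expects
    saved = {
        "redux_up": {"weight": torch.zeros(12288, 1152), "bias": torch.zeros(12288)},
        "redux_down": {"weight": torch.zeros(4096, 12288), "bias": torch.zeros(4096)},
    }
    # same recombination as the patch: "redux_up" + "." + "weight" -> "redux_up.weight"
    flat = {f"{k}.{k2}": v2 for k, v in saved.items() for k2, v2 in v.items()}
    assert sorted(flat) == ["redux_down.bias", "redux_down.weight",
                            "redux_up.bias", "redux_up.weight"]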