Added support for training Flux Redux adapters

Jaret Burkett
2024-11-21 20:01:52 -07:00
parent 96d418bb95
commit 67c2e44edb
3 changed files with 88 additions and 3 deletions

@@ -513,12 +513,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
        # move it back
        self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
    else:
        direct_save = False
        if self.adapter_config.train_only_image_encoder:
            direct_save = True
        if self.adapter_config.type == 'redux':
            direct_save = True
        save_ip_adapter_from_diffusers(
            state_dict,
            output_file=file_path,
            meta=save_meta,
            dtype=get_torch_dtype(self.save_config.dtype),
-           direct_save=self.adapter_config.train_only_image_encoder
+           direct_save=direct_save
        )
else:
    if self.save_config.save_format == "diffusers":

@@ -14,6 +14,7 @@ from toolkit.models.single_value_adapter import SingleValueAdapter
from toolkit.models.te_adapter import TEAdapter
from toolkit.models.te_aug_adapter import TEAugAdapter
from toolkit.models.vd_adapter import VisionDirectAdapter
from toolkit.models.redux import ReduxImageEncoder
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
@@ -93,6 +94,8 @@ class CustomAdapter(torch.nn.Module):
self.te_augmenter: TEAugAdapter = None
self.vd_adapter: VisionDirectAdapter = None
self.single_value_adapter: SingleValueAdapter = None
self.redux_adapter: ReduxImageEncoder = None
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -201,6 +204,8 @@ class CustomAdapter(torch.nn.Module):
    self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
elif self.adapter_type == 'single_value':
    self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
elif self.adapter_type == 'redux':
    self.redux_adapter = ReduxImageEncoder(1152, 4096, self.device, torch_dtype)
else:
    raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -423,6 +428,13 @@ class CustomAdapter(torch.nn.Module):
        self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
    except Exception as e:
        print(e)
if 'redux_up' in state_dict:
    # state dict is separated, so recombine it
    new_dict = {}
    for k, v in state_dict.items():
        for k2, v2 in v.items():
            new_dict[k + '.' + k2] = v2
    self.redux_adapter.load_state_dict(new_dict, strict=True)
pass
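For reference, the recombination above flattens the grouped layout the checkpoint comes back in: tensors arrive bucketed under their top-level module name, and the nested loop rebuilds the dotted keys that nn.Module.load_state_dict expects. A minimal sketch of that transformation, with shapes assumed from the ReduxImageEncoder defaults (1152 -> 4096 * 3 -> 4096):

import torch

# grouped layout, as loaded from the checkpoint (shapes follow the defaults)
nested = {
    'redux_up': {'weight': torch.zeros(12288, 1152), 'bias': torch.zeros(12288)},
    'redux_down': {'weight': torch.zeros(4096, 12288), 'bias': torch.zeros(4096)},
}
# rebuild the dotted keys: 'redux_up.weight', 'redux_up.bias', ...
flat = {f'{outer}.{inner}': t for outer, group in nested.items() for inner, t in group.items()}
# `flat` now matches ReduxImageEncoder().state_dict() key for key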
@@ -466,6 +478,11 @@ class CustomAdapter(torch.nn.Module):
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["ilora"] = self.ilora_module.state_dict()
return state_dict
elif self.adapter_type == 'redux':
d = self.redux_adapter.state_dict()
for k, v in d.items():
state_dict[k] = v
return state_dict
else:
raise NotImplementedError
@@ -482,7 +499,7 @@ class CustomAdapter(torch.nn.Module):
    prompt: Union[List[str], str],
    is_unconditional: bool = False,
):
-   if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+   if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux':
        return prompt
    elif self.adapter_type == 'text_encoder':
        # todo allow for training
@@ -604,7 +621,7 @@ class CustomAdapter(torch.nn.Module):
if self.adapter_type == 'ilora':
    return prompt_embeds
-if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
+if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'redux':
    if is_unconditional:
        # we don't condition the negative embeds for photo maker
        return prompt_embeds.clone()
@@ -626,6 +643,7 @@ class CustomAdapter(torch.nn.Module):
return_tensors="pt",
do_resize=True,
do_rescale=False,
do_convert_rgb=True
).pixel_values
else:
clip_image = tensors_0_1
@@ -706,7 +724,41 @@ class CustomAdapter(torch.nn.Module):
    )
    return prompt_embeds
elif self.adapter_type == 'redux':
    with torch.set_grad_enabled(is_training):
        if is_training and self.config.train_image_encoder:
            self.vision_encoder.train()
            clip_image = clip_image.requires_grad_(True)
            id_embeds = self.vision_encoder(
                clip_image,
                output_hidden_states=True,
            )
        else:
            with torch.no_grad():
                self.vision_encoder.eval()
                id_embeds = self.vision_encoder(
                    clip_image, output_hidden_states=True
                )
        img_embeds = id_embeds['last_hidden_state']

        if self.config.quad_image:
            # get the outputs of the quad chunks
            chunks = img_embeds.chunk(quad_count, dim=0)
            chunk_sum = torch.zeros_like(chunks[0])
            for chunk in chunks:
                chunk_sum = chunk_sum + chunk
            # get the mean of them
            img_embeds = chunk_sum / quad_count

        if not is_training or not self.config.train_image_encoder:
            img_embeds = img_embeds.detach()

        img_embeds = self.redux_adapter(img_embeds.to(self.device, get_torch_dtype(self.sd_ref().dtype)))
        prompt_embeds.text_embeds = torch.cat((prompt_embeds.text_embeds, img_embeds), dim=-2)
    return prompt_embeds
else:
    return prompt_embeds
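The net effect of the redux branch is to append projected image tokens to the T5 prompt embeddings along the sequence dimension, so the transformer attends to them like additional text tokens. A rough shape sketch, assuming 729 SigLIP patch tokens and a 512-token T5 sequence (both counts are illustrative, not taken from this diff):

import torch

text_embeds = torch.randn(1, 512, 4096)   # (batch, text_tokens, t5_dim)
patch_tokens = torch.randn(1, 729, 1152)  # (batch, image_tokens, siglip_dim)
redux = ReduxImageEncoder(redux_dim=1152, txt_in_features=4096)
img_embeds = redux(patch_tokens)          # -> (1, 729, 4096)
combined = torch.cat((text_embeds, img_embeds), dim=-2)
print(combined.shape)                     # torch.Size([1, 1241, 4096])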
@@ -945,6 +997,8 @@ class CustomAdapter(torch.nn.Module):
    yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'single_value':
    yield from self.single_value_adapter.parameters(recurse)
elif self.config.type == 'redux':
    yield from self.redux_adapter.parameters(recurse)
else:
    raise NotImplementedError
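With the redux branch wired into parameters(), an optimizer built from the adapter trains just the two Linear layers of ReduxImageEncoder, roughly 64.5M parameters at the default sizes. A hypothetical setup, where `adapter` is a CustomAdapter configured with type 'redux':

import torch

params = list(adapter.parameters())
total = sum(p.numel() for p in params)
print(f'{total / 1e6:.1f}M trainable params')  # ~64.5M at the default sizes
optimizer = torch.optim.AdamW(params, lr=1e-4)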

toolkit/models/redux.py

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn


class ReduxImageEncoder(torch.nn.Module):
    def __init__(
        self,
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self.redux_dim = redux_dim
        self.device = device
        self.dtype = dtype
        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
        self.redux_down = nn.Linear(
            txt_in_features * 3, txt_in_features, dtype=dtype)

    def forward(self, sigclip_embeds) -> torch.Tensor:
        x = self.redux_up(sigclip_embeds)
        x = torch.nn.functional.silu(x)
        projected_x = self.redux_down(x)
        return projected_x
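A quick usage sketch of the encoder (batch and token counts here are arbitrary): it widens each embedding to three times the T5 width, applies SiLU, then projects down to the T5 feature size, so the output is ready to concatenate with text embeddings.

import torch
from toolkit.models.redux import ReduxImageEncoder

encoder = ReduxImageEncoder(redux_dim=1152, txt_in_features=4096)
sigclip_embeds = torch.randn(2, 729, 1152)  # e.g. SigLIP patch embeddings
out = encoder(sigclip_embeds)
assert out.shape == (2, 729, 4096)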