mirror of https://github.com/ostris/ai-toolkit.git

Added support for training flux redux adapters
@@ -513,12 +513,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 # move it back
                 self.adapter = self.adapter.to(orig_device, dtype=orig_dtype)
             else:
+                direct_save = False
+                if self.adapter_config.train_only_image_encoder:
+                    direct_save = True
+                if self.adapter_config.type == 'redux':
+                    direct_save = True
                 save_ip_adapter_from_diffusers(
                     state_dict,
                     output_file=file_path,
                     meta=save_meta,
                     dtype=get_torch_dtype(self.save_config.dtype),
-                    direct_save=self.adapter_config.train_only_image_encoder
+                    direct_save=direct_save
                 )
     else:
         if self.save_config.save_format == "diffusers":
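Note: the two added if-statements reduce to a single boolean OR. A minimal equivalent sketch (not part of the commit, same config attributes assumed):

    # equivalent to the flag logic added above
    direct_save = (
        self.adapter_config.train_only_image_encoder
        or self.adapter_config.type == 'redux'
    )

Either way, redux checkpoints now take the direct-save path, which presumably writes the adapter state dict as-is rather than remapping it into the diffusers IP-adapter layout.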
@@ -14,6 +14,7 @@ from toolkit.models.single_value_adapter import SingleValueAdapter
 from toolkit.models.te_adapter import TEAdapter
 from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
+from toolkit.models.redux import ReduxImageEncoder
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
@@ -93,6 +94,8 @@ class CustomAdapter(torch.nn.Module):
         self.te_augmenter: TEAugAdapter = None
         self.vd_adapter: VisionDirectAdapter = None
         self.single_value_adapter: SingleValueAdapter = None
+        self.redux_adapter: ReduxImageEncoder = None

         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -201,6 +204,8 @@ class CustomAdapter(torch.nn.Module):
             self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
         elif self.adapter_type == 'single_value':
             self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
+        elif self.adapter_type == 'redux':
+            self.redux_adapter = ReduxImageEncoder(1152, 4096, self.device, torch_dtype)
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -423,6 +428,13 @@ class CustomAdapter(torch.nn.Module):
                 self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
             except Exception as e:
                 print(e)
+        if 'redux_up' in state_dict:
+            # state dict is separated, so recombine it
+            new_dict = {}
+            for k, v in state_dict.items():
+                for k2, v2 in v.items():
+                    new_dict[k + '.' + k2] = v2
+            self.redux_adapter.load_state_dict(new_dict, strict=True)

         pass
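The redux checkpoint arrives as a nested dict keyed by submodule ('redux_up', 'redux_down'), so the loop above flattens it back into the 'module.param' keys that load_state_dict expects. A self-contained sketch of the same recombination, with hypothetical tensor shapes matching ReduxImageEncoder(1152, 4096):

    import torch

    # hypothetical nested checkpoint, one entry per submodule
    state_dict = {
        'redux_up': {'weight': torch.zeros(12288, 1152), 'bias': torch.zeros(12288)},
        'redux_down': {'weight': torch.zeros(4096, 12288), 'bias': torch.zeros(4096)},
    }

    # recombine into flat 'module.param' keys, mirroring the loop in the commit
    new_dict = {}
    for k, v in state_dict.items():
        for k2, v2 in v.items():
            new_dict[k + '.' + k2] = v2

    print(sorted(new_dict.keys()))
    # ['redux_down.bias', 'redux_down.weight', 'redux_up.bias', 'redux_up.weight']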
@@ -466,6 +478,11 @@ class CustomAdapter(torch.nn.Module):
             state_dict["vision_encoder"] = self.vision_encoder.state_dict()
             state_dict["ilora"] = self.ilora_module.state_dict()
             return state_dict
+        elif self.adapter_type == 'redux':
+            d = self.redux_adapter.state_dict()
+            for k, v in d.items():
+                state_dict[k] = v
+            return state_dict
         else:
             raise NotImplementedError
@@ -482,7 +499,7 @@ class CustomAdapter(torch.nn.Module):
         prompt: Union[List[str], str],
         is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'redux':
            return prompt
        elif self.adapter_type == 'text_encoder':
            # todo allow for training
@@ -604,7 +621,7 @@ class CustomAdapter(torch.nn.Module):
         if self.adapter_type == 'ilora':
             return prompt_embeds

-        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'redux':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
                 return prompt_embeds.clone()
@@ -626,6 +643,7 @@ class CustomAdapter(torch.nn.Module):
                     return_tensors="pt",
                     do_resize=True,
+                    do_rescale=False,
                     do_convert_rgb=True
                 ).pixel_values
             else:
                 clip_image = tensors_0_1
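do_rescale=False matters here because tensors_0_1 is already a float tensor in [0, 1]; Hugging Face image processors typically default to rescale_factor=1/255, which assumes uint8 input in [0, 255]. A quick illustrative sketch of what double rescaling would do:

    import torch

    pixel = torch.rand(1, 3, 512, 512)       # float input already in [0, 1]
    print(pixel.max().item())                # ~1.0
    print((pixel * (1 / 255)).max().item())  # ~0.004, image crushed to near black

Hence rescaling is disabled when feeding pre-normalized float tensors.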
@@ -706,7 +724,41 @@ class CustomAdapter(torch.nn.Module):
             )
             return prompt_embeds

+        elif self.adapter_type == 'redux':
+            with torch.set_grad_enabled(is_training):
+                if is_training and self.config.train_image_encoder:
+                    self.vision_encoder.train()
+                    clip_image = clip_image.requires_grad_(True)
+                    id_embeds = self.vision_encoder(
+                        clip_image,
+                        output_hidden_states=True,
+                    )
+                else:
+                    with torch.no_grad():
+                        self.vision_encoder.eval()
+                        id_embeds = self.vision_encoder(
+                            clip_image, output_hidden_states=True
+                        )
+
+                img_embeds = id_embeds['last_hidden_state']
+
+                if self.config.quad_image:
+                    # get the outputs of the quad
+                    chunks = img_embeds.chunk(quad_count, dim=0)
+                    chunk_sum = torch.zeros_like(chunks[0])
+                    for chunk in chunks:
+                        chunk_sum = chunk_sum + chunk
+                    # get the mean of them
+                    img_embeds = chunk_sum / quad_count
+
+                if not is_training or not self.config.train_image_encoder:
+                    img_embeds = img_embeds.detach()
+
+                img_embeds = self.redux_adapter(img_embeds.to(self.device, get_torch_dtype(self.sd_ref().dtype)))
+
+                prompt_embeds.text_embeds = torch.cat((prompt_embeds.text_embeds, img_embeds), dim=-2)
+                return prompt_embeds
         else:
             return prompt_embeds
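Net effect of this branch: vision-encoder features are projected into the text-embedding space and appended along the sequence dimension, so the image acts as extra "text" tokens. A shape-level sketch under assumed dimensions (1152-dim vision features, 4096-dim text embeddings; batch size and token counts are made up):

    import torch
    import torch.nn as nn

    batch = 1
    img_tokens, vision_dim = 729, 1152   # e.g. SigLIP patch tokens (assumed)
    txt_tokens, txt_dim = 512, 4096      # e.g. T5 sequence for Flux (assumed)

    redux_up = nn.Linear(vision_dim, txt_dim * 3)
    redux_down = nn.Linear(txt_dim * 3, txt_dim)

    img_embeds = torch.randn(batch, img_tokens, vision_dim)
    img_embeds = redux_down(torch.nn.functional.silu(redux_up(img_embeds)))  # (1, 729, 4096)

    text_embeds = torch.randn(batch, txt_tokens, txt_dim)
    conditioned = torch.cat((text_embeds, img_embeds), dim=-2)
    print(conditioned.shape)  # torch.Size([1, 1241, 4096])

The dim=-2 concat is the key move: the projected image tokens ride along with the prompt embeddings through the rest of the model unchanged.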
@@ -945,6 +997,8 @@ class CustomAdapter(torch.nn.Module):
             yield from self.vision_encoder.parameters(recurse)
         elif self.config.type == 'single_value':
             yield from self.single_value_adapter.parameters(recurse)
+        elif self.config.type == 'redux':
+            yield from self.redux_adapter.parameters(recurse)
         else:
             raise NotImplementedError
toolkit/models/redux.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+
+class ReduxImageEncoder(torch.nn.Module):
+    def __init__(
+        self,
+        redux_dim: int = 1152,
+        txt_in_features: int = 4096,
+        device=None,
+        dtype=None,
+    ) -> None:
+        super().__init__()
+        self.redux_dim = redux_dim
+        self.device = device
+        self.dtype = dtype
+        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
+        self.redux_down = nn.Linear(
+            txt_in_features * 3, txt_in_features, dtype=dtype)
+
+    def forward(self, sigclip_embeds) -> torch.Tensor:
+        x = self.redux_up(sigclip_embeds)
+        x = torch.nn.functional.silu(x)
+
+        projected_x = self.redux_down(x)
+        return projected_x
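A quick usage sketch of the new module (dummy input; the 729-token count is an assumption for a 27x27 patch grid):

    import torch
    from toolkit.models.redux import ReduxImageEncoder  # path per this commit

    encoder = ReduxImageEncoder(redux_dim=1152, txt_in_features=4096)
    sigclip_embeds = torch.randn(2, 729, 1152)  # (batch, tokens, vision dim)
    out = encoder(sigclip_embeds)
    print(out.shape)  # torch.Size([2, 729, 4096])

The up-project / SiLU / down-project pattern is a standard two-layer MLP; the 3x expansion mirrors common feed-forward width ratios.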