diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index fea1e81f..beacb203 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -269,16 +269,7 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         is_active = self.adapter_ref().is_active
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
         query = attn.to_q(hidden_states)
@@ -297,7 +288,44 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # will be none if disabled
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from diffusers.models.embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # begin ip adapter
         if not is_active:
             ip_hidden_states = None
         else:
@@ -309,47 +337,6 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
                     raise ValueError("Unconditional is None but should not be")
                 ip_hidden_states = torch.cat([self.adapter_ref().last_unconditional, ip_hidden_states], dim=0)

-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-
-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-        # attention
-        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
-            from diffusers.models.embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # do ip adapter
-        # will be none if disabled
         if ip_hidden_states is not None:
             # apply scaler
             if self.train_scaler:
@@ -365,8 +352,6 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
             ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
             ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
             ip_hidden_states = F.scaled_dot_product_attention(
                 query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
             )
@@ -376,26 +361,23 @@ class CustomIPFluxAttnProcessor2_0(torch.nn.Module):
                 scale = self.scale

             hidden_states = hidden_states + scale * ip_hidden_states

+        # end ip adapter
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
-        encoder_hidden_states, hidden_states = (
-            hidden_states[:, : encoder_hidden_states.shape[1]],
-            hidden_states[:, encoder_hidden_states.shape[1] :],
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        return hidden_states, encoder_hidden_states
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states


 # loosely based on
 # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
 class IPAdapter(torch.nn.Module):
@@ -659,9 +641,9 @@ class IPAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 attn_processor_keys.append(f"transformer_blocks.{i}.attn")

-            # single transformer blocks do not have cross attn
-            # for i, module in transformer.single_transformer_blocks.named_children():
-            #     attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
+            # single transformer blocks do not have cross attn, but we will do them anyway
+            for i, module in transformer.single_transformer_blocks.named_children():
+                attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

@@ -695,7 +677,7 @@ class IPAdapter(torch.nn.Module):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = sd.unet.config['block_out_channels'][block_id]
-            elif name.startswith("transformer"):
+            elif name.startswith("transformer") or name.startswith("single_transformer"):
                 if is_flux:
                     hidden_size = 3072
                 else:
@@ -773,11 +755,20 @@ class IPAdapter(torch.nn.Module):
             transformer: FluxTransformer2DModel = sd.unet
             for i, module in transformer.transformer_blocks.named_children():
                 module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
+
+            # do single blocks too even though they dont have cross attn
+            for i, module in transformer.single_transformer_blocks.named_children():
+                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+
             self.adapter_modules = torch.nn.ModuleList(
                 [
                     transformer.transformer_blocks[i].attn.processor for i in
                     range(len(transformer.transformer_blocks))
-                ])
+                ] + [
+                    transformer.single_transformer_blocks[i].attn.processor for i in
+                    range(len(transformer.single_transformer_blocks))
+                ]
+            )
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index 02549242..c180c18b 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -323,17 +323,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
-        is_active = self.adapter_ref().is_active
-        input_ndim = hidden_states.ndim
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        context_input_ndim = encoder_hidden_states.ndim
-        if context_input_ndim == 4:
-            batch_size, channel, height, width = encoder_hidden_states.shape
-            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size = encoder_hidden_states.shape[0]
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         # `sample` projections.
         query = attn.to_q(hidden_states)
@@ -352,36 +342,34 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         if attn.norm_k is not None:
             key = attn.norm_k(key)

-        # `context` projections.
-        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

-        encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
-        encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-            batch_size, -1, attn.heads, head_dim
-        ).transpose(1, 2)
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)

-        if attn.norm_added_q is not None:
-            encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-        if attn.norm_added_k is not None:
-            encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

-        # attention
-        query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-        key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-        value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

         if image_rotary_emb is not None:
-            # YiYi to-do: update uising apply_rotary_emb
-            # from ..embeddings import apply_rotary_emb
-            # query = apply_rotary_emb(query, image_rotary_emb)
-            # key = apply_rotary_emb(key, image_rotary_emb)
             from diffusers.models.embeddings import apply_rotary_emb

             query = apply_rotary_emb(query, image_rotary_emb)
@@ -391,8 +379,7 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        # do ip adapter
-        # will be none if disabled
+        # begin ip adapter
         if self.is_active and self.conditional_embeds is not None:
             adapter_hidden_states = self.conditional_embeds
             if adapter_hidden_states.shape[0] < batch_size:
@@ -413,8 +400,6 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):
             vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
             vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
             vd_hidden_states = F.scaled_dot_product_attention(
                 query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
             )
@@ -424,25 +409,21 @@ class CustomFluxVDAttnProcessor2_0(torch.nn.Module):

             hidden_states = hidden_states + self.scale * vd_hidden_states

+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
-        encoder_hidden_states, hidden_states = (
-            hidden_states[:, : encoder_hidden_states.shape[1]],
-            hidden_states[:, encoder_hidden_states.shape[1] :],
-        )
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-        if context_input_ndim == 4:
-            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        return hidden_states, encoder_hidden_states
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states


 class VisionDirectAdapter(torch.nn.Module):
     def __init__(
@@ -482,9 +463,9 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 attn_processor_keys.append(f"transformer_blocks.{i}.attn")

-            # single transformer blocks do not have cross attn
-            # for i, module in transformer.single_transformer_blocks.named_children():
-            #     attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
+            # single transformer blocks do not have cross attn, but we will do them anyway
+            for i, module in transformer.single_transformer_blocks.named_children():
+                attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

@@ -501,7 +482,7 @@ class VisionDirectAdapter(torch.nn.Module):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = sd.unet.config['block_out_channels'][block_id]
-            elif name.startswith("transformer"):
+            elif name.startswith("transformer") or name.startswith("single_transformer"):
                 if is_flux:
                     hidden_size = 3072
                 else:
@@ -596,23 +577,32 @@ class VisionDirectAdapter(torch.nn.Module):
             transformer: FluxTransformer2DModel = sd.unet
             for i, module in transformer.transformer_blocks.named_children():
                 module.attn.processor = attn_procs[f"transformer_blocks.{i}.attn"]
+
+            # do single blocks too even though they dont have cross attn
+            for i, module in transformer.single_transformer_blocks.named_children():
+                module.attn.processor = attn_procs[f"single_transformer_blocks.{i}.attn"]
+
             self.adapter_modules = torch.nn.ModuleList(
                 [
                     transformer.transformer_blocks[i].attn.processor for i in
                     range(len(transformer.transformer_blocks))
-                ])
+                ] + [
+                    transformer.single_transformer_blocks[i].attn.processor for i in
+                    range(len(transformer.single_transformer_blocks))
+                ]
+            )
         else:
             sd.unet.set_attn_processor(attn_procs)
             self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

-        # add the mlp layer
-        self.mlp = MLP(
-            in_dim=self.token_size,
-            out_dim=self.token_size,
-            hidden_dim=self.token_size,
-            # dropout=0.1,
-            use_residual=True
-        )
+        # # add the mlp layer
+        # self.mlp = MLP(
+        #     in_dim=self.token_size,
+        #     out_dim=self.token_size,
+        #     hidden_dim=self.token_size,
+        #     # dropout=0.1,
+        #     use_residual=True
+        # )

     # make a getter to see if is active
     @property
@@ -620,4 +610,5 @@ class VisionDirectAdapter(torch.nn.Module):
         return self.adapter_ref().is_active

     def forward(self, input):
-        return self.mlp(input)
+        # return self.mlp(input)
+        return input
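Note (illustrative, not part of the patch): after this change the custom Flux attention processors are attached to both the joint transformer_blocks and the single_transformer_blocks, and they return a (hidden_states, encoder_hidden_states) tuple only when encoder_hidden_states is passed; single blocks get a single tensor back. The sketch below is a hypothetical sanity check one might run after building the adapter. Only transformer.transformer_blocks, transformer.single_transformer_blocks, and .attn.processor come from the diff; check_flux_processors and processor_cls are made-up names for illustration.

# Hypothetical helper, not part of the patch: confirm every Flux block got the custom processor.
def check_flux_processors(transformer, processor_cls):
    # Joint ("double") blocks receive encoder_hidden_states and expect a tuple back.
    for i, block in enumerate(transformer.transformer_blocks):
        assert isinstance(block.attn.processor, processor_cls), f"transformer_blocks.{i}.attn"
    # Single blocks have no cross attention; their processor returns only hidden_states.
    for i, block in enumerate(transformer.single_transformer_blocks):
        assert isinstance(block.attn.processor, processor_cls), f"single_transformer_blocks.{i}.attn"
    return len(transformer.transformer_blocks) + len(transformer.single_transformer_blocks)

For a stock FLUX.1 transformer this would typically count 19 double blocks plus 38 single blocks, i.e. 57 processors registered in adapter_modules.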