"""Cross-attention kernels and a diffusers-style attention processor.

The kernel used at runtime (xformers or PyTorch SDPA) is chosen once, at import
time, from ``args.xformers``.
"""

import torch

from backend.args import args

if args.xformers:
    import xformers
    import xformers.ops


def _xformers_attn_bias(mask, batch, heads, q_len, dtype, device):
    """Return ``mask`` as a ``(batch * heads, q_len, k_len)`` bias for xformers.

    The mask is copied into a ``dtype`` buffer whose last dimension is padded
    beyond ``k_len`` and then sliced back to ``k_len``. The returned view
    therefore keeps the padded (8-aligned) row stride, which xformers asks for
    when the key length is not a multiple of 8.

    The padding is derived from the key length (the mask's last dimension).
    Deriving it from ``q_len`` leaves the buffer too narrow whenever keys
    outnumber queries.

    Args:
        mask: A 4-D mask broadcastable to ``(batch, heads, q_len, k_len)``, or a
            mask of lower rank broadcastable to ``(batch * heads, q_len, k_len)``.
        batch: Batch size of the un-split query tensor.
        heads: Number of attention heads.
        q_len: Query sequence length.
        dtype: Dtype of the bias buffer (the query dtype).
        device: Device of the bias buffer (the query device).
    """
    k_len = mask.shape[-1]
    pad = 8 - k_len % 8
    if mask.ndim == 4:
        # Per-(batch, head) mask, e.g. the (batch, heads, -1, k_len) view built
        # by AttentionProcessorForge: broadcast it in 4-D, then fold heads into
        # the batch dimension (a view, since the buffer is contiguous).
        buf = torch.empty([batch, heads, q_len, k_len + pad], dtype=dtype, device=device)
        buf[..., :k_len] = mask
        buf = buf.reshape(batch * heads, q_len, k_len + pad)
    else:
        buf = torch.empty([batch * heads, q_len, k_len + pad], dtype=dtype, device=device)
        buf[..., :k_len] = mask
    return buf[..., :k_len]


def attention_xformers(q, k, v, heads, mask=None):
    """Multi-head attention via ``xformers.ops.memory_efficient_attention``.

    Args:
        q, k, v: Projected tensors of shape ``(batch, tokens, heads * dim_head)``.
            Rank 3 is fixed by the unpacking of ``q.shape``. ``k`` and ``v`` are
            assumed to share ``q``'s batch and inner size (TODO: confirm for
            all callers).
        heads: Number of attention heads; must divide the inner size.
        mask: Optional attention bias. A 4-D mask must broadcast to
            ``(batch, heads, q_tokens, k_tokens)``; lower-rank masks must
            broadcast to ``(batch * heads, q_tokens, k_tokens)``. Values are
            copied into a ``q.dtype`` tensor and passed as ``attn_bias``.

    Returns:
        Tensor of shape ``(batch, q_tokens, heads * dim_head)``.
    """
    b, _, inner_dim = q.shape
    dim_head = inner_dim // heads

    # (b, n, heads * d) -> (b * heads, n, d): xformers sees each head as an
    # independent batch entry.
    q, k, v = [
        t.reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous()
        for t in (q, k, v)
    ]

    if mask is not None:
        mask = _xformers_attn_bias(mask, b, heads, q.shape[1], q.dtype, q.device)

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    # (b * heads, n, d) -> (b, n, heads * d)
    return (
        out.reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )


def attention_pytorch(q, k, v, heads, mask=None):
    """Multi-head attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Takes the same arguments and returns the same shape as
    ``attention_xformers``. ``mask`` is forwarded unchanged as SDPA's
    ``attn_mask``, so it must broadcast to ``(batch, heads, q_tokens, k_tokens)``.
    """
    b, _, dim_head = q.shape
    dim_head //= heads

    # (b, n, heads * d) -> (b, heads, n, d). ``.view`` needs contiguous inputs,
    # which is assumed here (TODO: confirm for non-Linear projections).
    q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)


# Resolved once at import time; AttentionProcessorForge calls whichever kernel
# was selected here.
if args.xformers:
    print("Using xformers cross attention")
    attention_function = attention_xformers
else:
    print("Using pytorch cross attention")
    attention_function = attention_pytorch


class AttentionProcessorForge:
    """diffusers-style attention processor.

    The projections, norms and residual come from the ``attn`` module, and the
    core attention is delegated to the module-level ``attention_function``.
    """

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *_args,
        **_kwargs,
    ):
        """Run one attention layer.

        Args:
            attn: The attention module that provides ``to_q``/``to_k``/``to_v``,
                ``to_out``, the norms, and the residual/rescale settings.
            hidden_states: Either ``(B, C, H, W)`` feature maps or an already
                flattened 3-D token sequence.
            encoder_hidden_states: Key/value source for cross-attention. When
                ``None``, this is self-attention over ``hidden_states``.
            attention_mask: Optional mask, passed through
                ``attn.prepare_attention_mask``.
            temb: Optional embedding, used only by ``attn.spatial_norm``.
            *_args, **_kwargs: Extra arguments, accepted and ignored. They are
                named with a leading underscore so they do not shadow the
                module-level ``args`` config.

        Returns:
            A tensor with the same layout as the input ``hidden_states``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            # (B, C, H, W) -> (B, H*W, C) token sequence.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # Mask sizing follows the key/value sequence.
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # Presumably prepare_attention_mask folds heads into dim 0. Split
            # them back out so the mask is (batch, heads, -1, k_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        hidden_states = attention_function(query, key, value, heads=attn.heads, mask=attention_mask)

        # Output projection, then the second to_out module.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states