Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)

Commit: Added te aug adapter
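
In short, the commit threads a new 'te_augmenter' adapter type through SDTrainer, CustomAdapter, and StableDiffusion, and adds toolkit/models/te_aug_adapter.py, which patches every self-attention layer of the CLIP text encoder so that image-derived embeddings contribute an extra key/value attention path on top of the normal text attention. A rough, self-contained sketch of that idea (single head, illustrative names, and without the ZipperModule mixing the real adapter applies):

import torch
import torch.nn as nn
import torch.nn.functional as F

class AugmentedTextAttentionSketch(nn.Module):
    # Single-head sketch: regular self-attention over text tokens plus an extra
    # path whose keys/values come from image-derived embeddings (causal mask omitted).
    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        # adapter projections start as scaled-down copies so the augmentation is
        # close to a no-op at initialization (cf. the 0.01 scaling in the new file)
        self.k_proj_adapter = nn.Linear(dim, dim)
        self.v_proj_adapter = nn.Linear(dim, dim)
        self.k_proj_adapter.weight.data = self.k_proj.weight.data.clone() * 0.01
        self.v_proj_adapter.weight.data = self.v_proj.weight.data.clone() * 0.01

    def forward(self, text_hidden: torch.Tensor, image_embeds: torch.Tensor) -> torch.Tensor:
        scale = text_hidden.shape[-1] ** -0.5
        q = self.q_proj(text_hidden) * scale
        # regular text self-attention
        attn = F.softmax(q @ self.k_proj(text_hidden).transpose(1, 2), dim=-1)
        out = attn @ self.v_proj(text_hidden)
        # extra attention over the image tokens, added on top of the text output
        attn_aug = F.softmax(q @ self.k_proj_adapter(image_embeds).transpose(1, 2), dim=-1)
        out = out + attn_aug @ self.v_proj_adapter(image_embeds)
        return self.out_proj(out)

# text hidden states: (batch, 77, dim); image embeddings: (batch, n_tokens, dim)
sketch = AugmentedTextAttentionSketch(dim=64)
out = sketch(torch.randn(2, 77, 64), torch.randn(2, 257, 64))  # -> (2, 77, 64)
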
@@ -930,6 +930,9 @@ class SDTrainer(BaseSDTrainProcess):
 if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
     grad_on_text_encoder = True

+if self.adapter_config.type == 'te_augmenter':
+    grad_on_text_encoder = True
+
 # have a blank network so we can wrap it in a context and set multipliers without checking every time
 if self.network is not None:
     network = self.network
@@ -1045,6 +1048,8 @@ class SDTrainer(BaseSDTrainProcess):
 unconditional_embeds = None
 if grad_on_text_encoder:
     with torch.set_grad_enabled(True):
+        if isinstance(self.adapter, CustomAdapter):
+            self.adapter.is_unconditional_run = False
         conditional_embeds = self.sd.encode_prompt(
             conditioned_prompts, prompt_2,
             dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1053,6 +1058,8 @@ class SDTrainer(BaseSDTrainProcess):
             dtype=dtype)

         if self.train_config.do_cfg:
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = True
             # todo only do one and repeat it
             unconditional_embeds = self.sd.encode_prompt(
                 self.batch_negative_prompt,
@@ -1061,6 +1068,8 @@ class SDTrainer(BaseSDTrainProcess):
                 long_prompts=self.do_long_prompts).to(
                 self.device_torch,
                 dtype=dtype)
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False
 else:
     with torch.set_grad_enabled(False):
         # make sure it is in eval mode
@@ -1069,6 +1078,8 @@ class SDTrainer(BaseSDTrainProcess):
                 te.eval()
         else:
             self.sd.text_encoder.eval()
+        if isinstance(self.adapter, CustomAdapter):
+            self.adapter.is_unconditional_run = False
         conditional_embeds = self.sd.encode_prompt(
             conditioned_prompts, prompt_2,
             dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1076,12 +1087,16 @@ class SDTrainer(BaseSDTrainProcess):
             self.device_torch,
             dtype=dtype)
         if self.train_config.do_cfg:
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = True
             unconditional_embeds = self.sd.encode_prompt(
                 self.batch_negative_prompt,
                 dropout_prob=self.train_config.prompt_dropout_prob,
                 long_prompts=self.do_long_prompts).to(
                 self.device_torch,
                 dtype=dtype)
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False

 # detach the embeddings
 conditional_embeds = conditional_embeds.detach()
@@ -477,7 +477,8 @@ class TrainVAEProcess(BaseTrainProcess):
 if self.use_critic:
     loss_string += f" crD: {critic_d_loss:.2e}"

-if self.optimizer_type.startswith('dadaptation'):
+if self.optimizer_type.startswith('dadaptation') or \
+        self.optimizer_type.lower().startswith('prodigy'):
     learning_rate = (
         optimizer.param_groups[0]["d"] *
         optimizer.param_groups[0]["lr"]
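
The TrainVAEProcess change above only widens the learning-rate readout: like D-Adaptation, Prodigy stores its learned multiplier as "d" in each param group, so the effective learning rate is d * lr. A hypothetical helper (not in the repo) capturing the same readout:

def effective_lr(optimizer) -> float:
    # dadaptation / prodigy keep a learned multiplier "d" in the param group;
    # plain optimizers have no "d", so fall back to 1.0
    group = optimizer.param_groups[0]
    return group.get("d", 1.0) * group["lr"]
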
@@ -10,6 +10,7 @@ from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.ilora import InstantLoRAModule
 from toolkit.models.te_adapter import TEAdapter
+from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -60,6 +61,7 @@ class CustomAdapter(torch.nn.Module):
 self.current_scale = 1.0
 self.is_active = True
 self.flag_word = "fla9wor0"
+self.is_unconditional_run = False

 self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None

@@ -83,6 +85,7 @@ class CustomAdapter(torch.nn.Module):
 self.te: Union[T5EncoderModel, CLIPTextModel] = None
 self.tokenizer: CLIPTokenizer = None
 self.te_adapter: TEAdapter = None
+self.te_augmenter: TEAugAdapter = None
 self.vd_adapter: VisionDirectAdapter = None
 self.conditional_embeds: Optional[torch.Tensor] = None
 self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -149,6 +152,8 @@ class CustomAdapter(torch.nn.Module):
         raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")

     self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+elif self.adapter_type == 'te_augmenter':
+    self.te_augmenter = TEAugAdapter(self, self.sd_ref())
 elif self.adapter_type == 'vision_direct':
     self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
 else:
@@ -269,9 +274,13 @@ class CustomAdapter(torch.nn.Module):
     preprocessor_input_size = self.vision_encoder.config.image_size * 2

 # update the preprocessor so images come in at the right size
-self.image_processor.size['shortest_edge'] = preprocessor_input_size
-self.image_processor.crop_size['height'] = preprocessor_input_size
-self.image_processor.crop_size['width'] = preprocessor_input_size
+if 'height' in self.image_processor.size:
+    self.image_processor.size['height'] = preprocessor_input_size
+    self.image_processor.size['width'] = preprocessor_input_size
+elif hasattr(self.image_processor, 'crop_size'):
+    self.image_processor.size['shortest_edge'] = preprocessor_input_size
+    self.image_processor.crop_size['height'] = preprocessor_input_size
+    self.image_processor.crop_size['width'] = preprocessor_input_size

 if self.config.image_encoder_arch == 'clip+':
     # self.image_processor.config
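
Context for the replaced preprocessor block above: Hugging Face image processors describe their target resolution in two layouts, either size={'height': ..., 'width': ...} or size={'shortest_edge': ...} plus a separate crop_size dict, and the new branch handles both. The same dispatch as a standalone, illustrative helper (name is hypothetical):

def set_processor_resolution(image_processor, resolution: int) -> None:
    # layout 1: explicit height/width in `size`
    if 'height' in image_processor.size:
        image_processor.size['height'] = resolution
        image_processor.size['width'] = resolution
    # layout 2: shortest_edge resize followed by a center crop
    elif hasattr(image_processor, 'crop_size'):
        image_processor.size['shortest_edge'] = resolution
        image_processor.crop_size['height'] = resolution
        image_processor.crop_size['width'] = resolution
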
@@ -340,6 +349,9 @@ class CustomAdapter(torch.nn.Module):
 if 'te_adapter' in state_dict:
     self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)

+if 'te_augmenter' in state_dict:
+    self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
+
 if 'vd_adapter' in state_dict:
     self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
 if 'dvadapter' in state_dict:
@@ -378,6 +390,11 @@ class CustomAdapter(torch.nn.Module):
 elif self.adapter_type == 'text_encoder':
     state_dict["te_adapter"] = self.te_adapter.state_dict()
     return state_dict
+elif self.adapter_type == 'te_augmenter':
+    if self.config.train_image_encoder:
+        state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+    state_dict["te_augmenter"] = self.te_augmenter.state_dict()
+    return state_dict
 elif self.adapter_type == 'vision_direct':
     state_dict["dvadapter"] = self.vd_adapter.state_dict()
     if self.config.train_image_encoder:
@@ -647,7 +664,7 @@ class CustomAdapter(torch.nn.Module):
         has_been_preprocessed=False,
         quad_count=4,
 ) -> PromptEmbeds:
-    if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+    if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
         if tensors_0_1 is None:
             tensors_0_1 = self.get_empty_clip_image(1)
             has_been_preprocessed = True
@@ -675,7 +692,7 @@ class CustomAdapter(torch.nn.Module):
     clip_image = tensors_0_1

 batch_size = clip_image.shape[0]
-if self.adapter_type == 'vision_direct':
+if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
     # add an unconditional so we can save it
     unconditional = self.get_empty_clip_image(batch_size).to(
         clip_image.device, dtype=clip_image.dtype
@@ -730,7 +747,7 @@ class CustomAdapter(torch.nn.Module):
         img_embeds = img_embeds.detach()

     self.ilora_module(img_embeds)
-if self.adapter_type == 'vision_direct':
+if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
     with torch.set_grad_enabled(is_training):
         if is_training and self.config.train_image_encoder:
             self.vision_encoder.train()
@@ -754,8 +771,14 @@ class CustomAdapter(torch.nn.Module):
         if not is_training or not self.config.train_image_encoder:
             clip_image_embeds = clip_image_embeds.detach()

+        if self.adapter_type == 'te_augmenter':
+            clip_image_embeds = self.te_augmenter(clip_image_embeds)
+
         # save them to the conditional and unconditional
-        self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+        try:
+            self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+        except ValueError:
+            raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")

 def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
     if self.config.train_only_image_encoder:
@@ -781,6 +804,10 @@ class CustomAdapter(torch.nn.Module):
         yield from attn_processor.parameters(recurse)
     if self.config.train_image_encoder:
         yield from self.vision_encoder.parameters(recurse)
+elif self.config.type == 'te_augmenter':
+    yield from self.te_augmenter.parameters(recurse)
+    if self.config.train_image_encoder:
+        yield from self.vision_encoder.parameters(recurse)
 else:
     raise NotImplementedError

toolkit/models/te_aug_adapter.py  (new file, 253 lines)
@@ -0,0 +1,253 @@
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING, Optional, Tuple

from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention

from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
from toolkit.paths import REPOS_ROOT
from toolkit.resampler import Resampler

sys.path.append(REPOS_ROOT)

from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion
    from toolkit.custom_adapter import CustomAdapter


class TEAugAdapterCLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
        super().__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
        self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
        self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
        # initialize from a scaled-down copy of the original module's weights
        self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
        self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
        # scale the bias down as well
        self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
        self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001

        self.zipper = ZipperModule(
            in_size=attn_module.embed_dim,
            in_tokens=77 * 2,
            out_size=attn_module.embed_dim,
            out_tokens=77,
            hidden_size=attn_module.embed_dim,
            hidden_tokens=77,
        )
        # self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data)
        # self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data)
        # # reset the bias
        # self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data)
        # self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data)

        # replace the original forward with our forward
        self.original_forward = attn_module.forward
        attn_module.forward = self.forward


    @property
    def is_active(self):
        return self.adapter_ref().is_active

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        attn_module = self.attn_module_ref()

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = attn_module.q_proj(hidden_states) * attn_module.scale
        key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
        value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
        query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights has to be reshaped
            # twice and has to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
        if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
            # apply the adapter

            if adapter.is_unconditional_run:
                embeds = adapter.unconditional_embeds
            else:
                embeds = adapter.conditional_embeds
            # if the shape is not the same on batch, we are doing cfg and need to concat unconditional as well
            if embeds.size(0) != bsz:
                embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)

            key_states_raw = self.k_proj_adapter(embeds)
            key_states = attn_module._shape(key_states_raw, -1, bsz)
            value_states_raw = self.v_proj_adapter(embeds)
            value_states = attn_module._shape(value_states_raw, -1, bsz)
            key_states = key_states.view(*proj_shape)
            value_states = value_states.view(*proj_shape)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
            attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
            attn_output_adapter = torch.bmm(attn_probs, value_states)

            if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
                raise ValueError(
                    f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
                    f" {attn_output_adapter.size()}"
                )

            attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
            attn_output_adapter = attn_output_adapter.transpose(1, 2)
            attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)

            attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))

            # attn_output_adapter = attn_module.out_proj(attn_output_adapter)
            attn_output = attn_output + attn_output_adapter

        attn_output = attn_module.out_proj(attn_output)

        return attn_output, attn_weights_reshaped


class TEAugAdapter(torch.nn.Module):
    def __init__(
        self,
        adapter: 'CustomAdapter',
        sd: 'StableDiffusion',
    ):
        super(TEAugAdapter, self).__init__()
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.sd_ref: weakref.ref = weakref.ref(sd)

        if isinstance(sd.text_encoder, list):
            raise ValueError("Dual text encoders is not yet supported")

        # dim will come from text encoder
        # dim = sd.unet.config['cross_attention_dim']
        text_encoder: CLIPTextModel = sd.text_encoder
        dim = text_encoder.config.hidden_size

        clip_encoder: CLIPEncoder = text_encoder.text_model.encoder
        # dim = clip_encoder.layers[-1].self_attn

        if hasattr(adapter.vision_encoder.config, 'hidden_sizes'):
            embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
        else:
            embedding_dim = adapter.vision_encoder.config.hidden_size

        image_encoder_state_dict = adapter.vision_encoder.state_dict()
        # max_seq_len = CLIP tokens + CLS token
        in_tokens = 257
        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
            # clip
            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])

        if adapter.config.image_encoder_arch.startswith('convnext'):
            in_tokens = 16 * 16
            embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]

        out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens
        self.image_proj_model = ZipperModule(
            in_size=embedding_dim,
            in_tokens=in_tokens,
            out_size=dim,
            out_tokens=out_tokens,
            hidden_size=dim,
            hidden_tokens=out_tokens,
        )
        # init adapter modules
        attn_procs = {}
        for idx, layer in enumerate(clip_encoder.layers):
            name = f"clip_attention.{idx}"
            attn_procs[name] = TEAugAdapterCLIPAttention(
                layer.self_attn,
                self
            )

        self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values()))

    # make a getter to see if is active
    @property
    def is_active(self):
        return self.adapter_ref().is_active


    def forward(self, input):
        # # apply the adapter
        input = self.image_proj_model(input)
        # self.embeds = input
        return input
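
A note on how te_aug_adapter.py attaches itself: TEAugAdapterCLIPAttention keeps a reference to the original CLIPAttention.forward and then assigns its own bound forward to the wrapped instance, so every call through the text encoder is routed into the wrapper while the original weights stay untouched. A minimal, generic sketch of that pattern (module names are illustrative, not from the repo):

import torch
import torch.nn as nn

class ForwardPatchSketch(nn.Module):
    # Wrap an existing module's forward and add a small adapter path on top.
    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base
        self.adapter = nn.Linear(base.in_features, base.out_features)
        # zero-init here so the patched module initially matches the original;
        # the real adapter instead starts from scaled copies of the k/v weights
        nn.init.zeros_(self.adapter.weight)
        nn.init.zeros_(self.adapter.bias)
        # keep the original bound method, then take over the instance's forward
        self.original_forward = base.forward
        base.forward = self.forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.original_forward(x) + self.adapter(x)

base = nn.Linear(8, 8)
patched = ForwardPatchSketch(base)
y = base(torch.randn(2, 8))  # now routed through ForwardPatchSketch.forward
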
@@ -585,11 +585,17 @@ class StableDiffusion:
 )

 # encode the prompt ourselves so we can do fun stuff with embeddings
+if isinstance(self.adapter, CustomAdapter):
+    self.adapter.is_unconditional_run = False
 conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)

+if isinstance(self.adapter, CustomAdapter):
+    self.adapter.is_unconditional_run = True
 unconditional_embeds = self.encode_prompt(
     gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
 )
+if isinstance(self.adapter, CustomAdapter):
+    self.adapter.is_unconditional_run = False

 # allow any manipulations to take place to embeddings
 gen_config.post_process_embeddings(