mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Bug fixes. allow for random negative prompts
This commit is contained in:
@@ -322,7 +322,7 @@ class IPAdapter(torch.nn.Module):
|
||||
elif adapter_config.type == 'ip+':
|
||||
heads = 12 if not sd.is_xl else 20
|
||||
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
|
||||
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
|
||||
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
|
||||
self.image_encoder.config.hidden_sizes[-1]
|
||||
|
||||
image_encoder_state_dict = self.image_encoder.state_dict()
|
||||
@@ -340,6 +340,10 @@ class IPAdapter(torch.nn.Module):
|
||||
dim = 4096
|
||||
output_dim = 4096
|
||||
|
||||
if self.config.image_encoder_arch.startswith('convnext'):
|
||||
in_tokens = 16 * 16
|
||||
embedding_dim = self.image_encoder.config.hidden_sizes[-1]
|
||||
|
||||
# ip-adapter-plus
|
||||
image_proj_model = Resampler(
|
||||
dim=dim,
|
||||
@@ -406,6 +410,8 @@ class IPAdapter(torch.nn.Module):
|
||||
else:
|
||||
attn_processor_keys = list(sd.unet.attn_processors.keys())
|
||||
|
||||
attn_processor_names = []
|
||||
|
||||
for name in attn_processor_keys:
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
|
||||
sd.unet.config['cross_attention_dim']
|
||||
@@ -446,6 +452,9 @@ class IPAdapter(torch.nn.Module):
|
||||
}
|
||||
|
||||
attn_procs[name].load_state_dict(weights)
|
||||
attn_processor_names.append(name)
|
||||
print(f"Attn Processors")
|
||||
print(attn_processor_names)
|
||||
if self.sd_ref().is_pixart:
|
||||
# we have to set them ourselves
|
||||
transformer: Transformer2DModel = sd.unet
|
||||
@@ -690,6 +699,12 @@ class IPAdapter(torch.nn.Module):
|
||||
else:
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
|
||||
if self.config.image_encoder_arch.startswith('convnext'):
|
||||
# flatten the width height layers to make the token space
|
||||
clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
|
||||
# rearrange to (batch, tokens, size)
|
||||
clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
|
||||
|
||||
if self.config.quad_image:
|
||||
# get the outputs of the quat
|
||||
chunks = clip_image_embeds.chunk(quad_count, dim=0)
|
||||
|
||||
Reference in New Issue
Block a user