Bug fixes. allow for random negative prompts

This commit is contained in:
Jaret Burkett
2024-02-21 04:51:52 -07:00
parent 2478554c95
commit 49c41e6a5f
8 changed files with 166 additions and 6 deletions

View File

@@ -322,7 +322,7 @@ class IPAdapter(torch.nn.Module):
elif adapter_config.type == 'ip+':
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
self.image_encoder.config.hidden_sizes[-1]
image_encoder_state_dict = self.image_encoder.state_dict()
@@ -340,6 +340,10 @@ class IPAdapter(torch.nn.Module):
dim = 4096
output_dim = 4096
if self.config.image_encoder_arch.startswith('convnext'):
in_tokens = 16 * 16
embedding_dim = self.image_encoder.config.hidden_sizes[-1]
# ip-adapter-plus
image_proj_model = Resampler(
dim=dim,
@@ -406,6 +410,8 @@ class IPAdapter(torch.nn.Module):
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
attn_processor_names = []
for name in attn_processor_keys:
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
sd.unet.config['cross_attention_dim']
@@ -446,6 +452,9 @@ class IPAdapter(torch.nn.Module):
}
attn_procs[name].load_state_dict(weights)
attn_processor_names.append(name)
print(f"Attn Processors")
print(attn_processor_names)
if self.sd_ref().is_pixart:
# we have to set them ourselves
transformer: Transformer2DModel = sd.unet
@@ -690,6 +699,12 @@ class IPAdapter(torch.nn.Module):
else:
clip_image_embeds = clip_output.image_embeds
if self.config.image_encoder_arch.startswith('convnext'):
# flatten the width height layers to make the token space
clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
# rearrange to (batch, tokens, size)
clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
if self.config.quad_image:
# get the outputs of the quat
chunks = clip_image_embeds.chunk(quad_count, dim=0)