Added comparative loss when training CLIP encoder. Allow selecting CLIP layer on IP adapter. Improvements to prior prediction

Jaret Burkett
2024-02-05 07:40:03 -07:00
parent 177c7130ec
commit e18e0cb5f8
7 changed files with 227 additions and 65 deletions
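
Note: the comparative loss named in the commit message is not visible in the hunks below, which cover the adapter-side refactor. As a rough, hypothetical sketch of the idea (the model name, cosine form, and lack of weighting are all assumptions, not this commit's code), a comparative loss can keep a trainable CLIP image encoder close to a frozen copy of itself while it is fine-tuned:

# Hypothetical sketch only -- not the commit's actual implementation.
import copy

import torch
import torch.nn.functional as F
from transformers import CLIPVisionModel

encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
# frozen reference copy of the encoder before fine-tuning
frozen_copy = copy.deepcopy(encoder).eval().requires_grad_(False)

def comparative_loss(pixel_values: torch.Tensor) -> torch.Tensor:
    # features from the encoder currently being trained
    trained = encoder(pixel_values).last_hidden_state
    with torch.no_grad():
        # reference features from the untouched frozen copy
        reference = frozen_copy(pixel_values).last_hidden_state
    # penalize drift away from the frozen encoder's features
    return (1.0 - F.cosine_similarity(trained, reference, dim=-1)).mean()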


@@ -438,7 +438,10 @@ class CustomAdapter(torch.nn.Module):
             is_unconditional=False,
             quad_count=4,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
+        if self.adapter_type == 'ilora':
+            return prompt_embeds
+
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
             if is_unconditional:
                 # we don't condition the negative embeds for photo maker
                 return prompt_embeds.clone()
@@ -503,7 +506,7 @@ class CustomAdapter(torch.nn.Module):
                     self.token_mask
                 )
             return prompt_embeds
-        elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
+        elif self.adapter_type == 'clip_fusion':
             with torch.set_grad_enabled(is_training):
                 if is_training and self.config.train_image_encoder:
                     self.vision_encoder.train()
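
The next hunk moves the 'ilora' image-encoding path out of condition_prompt and into a new trigger_pre_te method. That method asks the vision encoder for output_hidden_states=True; the diff only reads last_hidden_state, but the full tuple of hidden states is what a CLIP-layer selection knob (mentioned in the commit message) would index into. A minimal sketch, assuming a hypothetical clip_layer index that is not in this diff:

# Hypothetical sketch: select an intermediate CLIP layer instead of the last.
import torch
from transformers import CLIPVisionModel

vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
pixel_values = torch.rand(1, 3, 224, 224)  # dummy input for illustration

with torch.no_grad():
    output = vision_encoder(pixel_values, output_hidden_states=True)

# hidden_states[0] is the embedding output; later entries are per-layer outputs.
clip_layer = -2  # assumed config knob, e.g. the penultimate layer
img_embeds = output.hidden_states[clip_layer]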
@@ -535,22 +538,96 @@ class CustomAdapter(torch.nn.Module):
                 if not is_training or not self.config.train_image_encoder:
                     img_embeds = img_embeds.detach()
-            if self.adapter_type == 'ilora':
-                self.ilora_module.img_embeds = img_embeds
-                return prompt_embeds
-            else:
-                prompt_embeds.text_embeds = self.clip_fusion_module(
-                    prompt_embeds.text_embeds,
-                    img_embeds
-                )
-                return prompt_embeds
+            prompt_embeds.text_embeds = self.clip_fusion_module(
+                prompt_embeds.text_embeds,
+                img_embeds
+            )
+            return prompt_embeds
         else:
             raise NotImplementedError
 
+    def trigger_pre_te(
+            self,
+            tensors_0_1: torch.Tensor,
+            is_training=False,
+            has_been_preprocessed=False,
+            quad_count=4,
+    ) -> PromptEmbeds:
+        if self.adapter_type == 'ilora':
+            with torch.no_grad():
+                # when training, the clip image is created in the dataloader
+                if not has_been_preprocessed:
+                    # tensors should be 0-1
+                    if tensors_0_1.ndim == 3:
+                        tensors_0_1 = tensors_0_1.unsqueeze(0)
+                    # training tensors are 0-1
+                    tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+                    # if images are out of this range, throw an error
+                    if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                        raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                            tensors_0_1.min(), tensors_0_1.max()
+                        ))
+
+                    clip_image = self.image_processor(
+                        images=tensors_0_1,
+                        return_tensors="pt",
+                        do_resize=True,
+                        do_rescale=False,
+                    ).pixel_values
+                else:
+                    clip_image = tensors_0_1
+                clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+
+                if self.config.quad_image:
+                    # split the 2x2 grid and stack on batch
+                    ci1, ci2 = clip_image.chunk(2, dim=2)
+                    ci1, ci3 = ci1.chunk(2, dim=3)
+                    ci2, ci4 = ci2.chunk(2, dim=3)
+                    to_cat = []
+                    for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+                        if i < quad_count:
+                            to_cat.append(ci)
+                        else:
+                            break
+
+                    clip_image = torch.cat(to_cat, dim=0).detach()
+
+        if self.adapter_type == 'ilora':
+            with torch.set_grad_enabled(is_training):
+                if is_training and self.config.train_image_encoder:
+                    self.vision_encoder.train()
+                    clip_image = clip_image.requires_grad_(True)
+                    id_embeds = self.vision_encoder(
+                        clip_image,
+                        output_hidden_states=True,
+                    )
+                else:
+                    with torch.no_grad():
+                        self.vision_encoder.eval()
+                        id_embeds = self.vision_encoder(
+                            clip_image, output_hidden_states=True
+                        )
+
+                img_embeds = id_embeds['last_hidden_state']
+
+                if self.config.quad_image:
+                    # get the outputs of the quad
+                    chunks = img_embeds.chunk(quad_count, dim=0)
+                    chunk_sum = torch.zeros_like(chunks[0])
+                    for chunk in chunks:
+                        chunk_sum = chunk_sum + chunk
+                    # take the mean of them
+                    img_embeds = chunk_sum / quad_count
+
+                if not is_training or not self.config.train_image_encoder:
+                    img_embeds = img_embeds.detach()
+
+                self.ilora_module.img_embeds = img_embeds
+
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.type == 'photo_maker':
             yield from self.fuse_module.parameters(recurse)
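
For reference, the quad_image path in the hunk above packs four images into a 2x2 grid, splits the grid back into tiles stacked on the batch dimension before the vision encoder, then averages the per-tile embeddings afterwards. A standalone sketch of that round trip (tile size is an assumption for illustration):

# Standalone sketch of the quad_image split/merge in trigger_pre_te above.
import torch

def split_quad(clip_image: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # clip_image: (batch, channels, height, width) holding a 2x2 grid of tiles
    ci1, ci2 = clip_image.chunk(2, dim=2)  # top half / bottom half
    ci1, ci3 = ci1.chunk(2, dim=3)         # top-left / top-right
    ci2, ci4 = ci2.chunk(2, dim=3)         # bottom-left / bottom-right
    # keep the first quad_count tiles, stacked on the batch dim
    return torch.cat([ci1, ci2, ci3, ci4][:quad_count], dim=0)

def merge_quad(img_embeds: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # undo the batch stacking by averaging the per-tile embeddings
    return torch.stack(img_embeds.chunk(quad_count, dim=0)).mean(dim=0)

grid = torch.rand(1, 3, 448, 448)  # four 224x224 tiles in one grid image
tiles = split_quad(grid)           # -> (4, 3, 224, 224), one tile per batch entry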