mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Added comparative loss when training the CLIP encoder. Allow selecting the CLIP layer on the IP adapter. Improvements to prior prediction.
@@ -438,7 +438,10 @@ class CustomAdapter(torch.nn.Module):
            is_unconditional=False,
            quad_count=4,
    ) -> PromptEmbeds:
        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
+            if self.adapter_type == 'ilora':
+                return prompt_embeds
+
            if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
                if is_unconditional:
                    # we don't condition the negative embeds for photo maker
                    return prompt_embeds.clone()
@@ -503,7 +506,7 @@ class CustomAdapter(torch.nn.Module):
                    self.token_mask
                )
                return prompt_embeds
-        elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
+        elif self.adapter_type == 'clip_fusion':
            with torch.set_grad_enabled(is_training):
                if is_training and self.config.train_image_encoder:
                    self.vision_encoder.train()
@@ -535,22 +538,96 @@ class CustomAdapter(torch.nn.Module):
                if not is_training or not self.config.train_image_encoder:
                    img_embeds = img_embeds.detach()

-                if self.adapter_type == 'ilora':
-                    self.ilora_module.img_embeds = img_embeds
-
-                    return prompt_embeds
-                else:
-
-                    prompt_embeds.text_embeds = self.clip_fusion_module(
-                        prompt_embeds.text_embeds,
-                        img_embeds
-                    )
-                    return prompt_embeds
+                prompt_embeds.text_embeds = self.clip_fusion_module(
+                    prompt_embeds.text_embeds,
+                    img_embeds
+                )
+                return prompt_embeds
        else:
            raise NotImplementedError
+
+    def trigger_pre_te(
+            self,
+            tensors_0_1: torch.Tensor,
+            is_training=False,
+            has_been_preprocessed=False,
+            quad_count=4,
+    ) -> PromptEmbeds:
+        if self.adapter_type == 'ilora':
+            with torch.no_grad():
+                # during training, the clip image is created in the dataloader
+                if not has_been_preprocessed:
+                    # tensors should be 0-1
+                    if tensors_0_1.ndim == 3:
+                        tensors_0_1 = tensors_0_1.unsqueeze(0)
+                    # training tensors are 0-1
+                    tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+                    # if images are out of this range (with a small tolerance), throw an error
+                    if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
+                        raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
+                            tensors_0_1.min(), tensors_0_1.max()
+                        ))
+                    clip_image = self.image_processor(
+                        images=tensors_0_1,
+                        return_tensors="pt",
+                        do_resize=True,
+                        do_rescale=False,
+                    ).pixel_values
+                else:
+                    clip_image = tensors_0_1
+                clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
+
+                if self.config.quad_image:
+                    # split the 2x2 grid and stack on batch
+                    ci1, ci2 = clip_image.chunk(2, dim=2)
+                    ci1, ci3 = ci1.chunk(2, dim=3)
+                    ci2, ci4 = ci2.chunk(2, dim=3)
+                    to_cat = []
+                    for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+                        if i < quad_count:
+                            to_cat.append(ci)
+                        else:
+                            break
+
+                    clip_image = torch.cat(to_cat, dim=0).detach()
+
+            if self.adapter_type == 'ilora':
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            output_hidden_states=True,
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, output_hidden_states=True
+                            )
+
+                    img_embeds = id_embeds['last_hidden_state']
+
+                    if self.config.quad_image:
+                        # get the outputs of the quad chunks
+                        chunks = img_embeds.chunk(quad_count, dim=0)
+                        chunk_sum = torch.zeros_like(chunks[0])
+                        for chunk in chunks:
+                            chunk_sum = chunk_sum + chunk
+                        # take the mean of them
+                        img_embeds = chunk_sum / quad_count
+
+                    if not is_training or not self.config.train_image_encoder:
+                        img_embeds = img_embeds.detach()
+
+                    self.ilora_module.img_embeds = img_embeds

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.type == 'photo_maker':
            yield from self.fuse_module.parameters(recurse)
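
A few illustrative sketches follow; none of this code is from the commit. First, the preprocessing branch of trigger_pre_te, assuming image_processor is a transformers CLIPImageProcessor (the repo's actual processor and its settings may differ):

import torch
from transformers import CLIPImageProcessor

# assumption: a default CLIPImageProcessor stands in for the repo's configured one
processor = CLIPImageProcessor()

tensors_0_1 = torch.rand(2, 3, 512, 512)  # training tensors are 0-1
if tensors_0_1.ndim == 3:
    tensors_0_1 = tensors_0_1.unsqueeze(0)  # single image -> batch of one
# tolerate slight overshoot (e.g. augmentation noise) but reject real outliers
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
    raise ValueError("image tensor values must be between 0 and 1")
clip_image = processor(
    images=tensors_0_1,
    return_tensors="pt",
    do_resize=True,    # resize to the CLIP input resolution (224 by default)
    do_rescale=False,  # inputs are already 0-1, skip the 1/255 rescale
).pixel_values         # -> (2, 3, 224, 224), normalized for CLIP
print(clip_image.shape)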
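
The quad_image path packs a 2x2 grid of images into a single tensor, splits it onto the batch dimension before encoding, then averages the per-quadrant embeddings. A self-contained sketch with a stand-in encoder (split_quad and merge_quad_embeds are hypothetical names, not repo functions):

import torch

def split_quad(clip_image: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # split the 2x2 grid along height (dim 2) and width (dim 3), stack on batch
    ci1, ci2 = clip_image.chunk(2, dim=2)
    ci1, ci3 = ci1.chunk(2, dim=3)
    ci2, ci4 = ci2.chunk(2, dim=3)
    return torch.cat([ci1, ci2, ci3, ci4][:quad_count], dim=0)

def merge_quad_embeds(img_embeds: torch.Tensor, quad_count: int = 4) -> torch.Tensor:
    # average the per-quadrant embeddings back to one embedding per grid
    chunks = img_embeds.chunk(quad_count, dim=0)
    return torch.stack(chunks, dim=0).mean(dim=0)

grid = torch.rand(1, 3, 448, 448)    # one 2x2 grid of 224x224 images
batch = split_quad(grid)             # -> (4, 3, 224, 224)
embeds = batch.flatten(1) @ torch.rand(3 * 224 * 224, 8)  # stand-in encoder
merged = merge_quad_embeds(embeds)   # -> (1, 8), mean over the 4 quadrants
print(batch.shape, merged.shape)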
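
The grad handling around the vision encoder follows a common freeze/train pattern: gradients flow only when the image encoder itself is being trained; otherwise the embeddings are detached so downstream modules train against fixed features. A sketch with a stand-in encoder:

import torch

encoder = torch.nn.Linear(16, 8)  # stand-in for the vision encoder
x = torch.rand(2, 16)
for is_training, train_image_encoder in [(True, True), (True, False), (False, False)]:
    inp = x.clone()
    with torch.set_grad_enabled(is_training):
        if is_training and train_image_encoder:
            encoder.train()
            out = encoder(inp.requires_grad_(True))
        else:
            with torch.no_grad():
                encoder.eval()
                out = encoder(inp)
    if not is_training or not train_image_encoder:
        out = out.detach()  # freeze embeddings for downstream modules
    print(is_training, train_image_encoder, out.requires_grad)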
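
Finally, the commit message mentions a comparative loss when training the CLIP encoder, but that change is not in the hunks shown above. Purely as a guess at the general shape, such a loss might penalize drift of the trainable encoder's embeddings away from a frozen copy:

import torch
import torch.nn.functional as F

def comparative_loss(trained_embeds: torch.Tensor,
                     frozen_embeds: torch.Tensor) -> torch.Tensor:
    # cosine distance between trainable and frozen embeddings; an MSE term
    # would be another reasonable choice (an assumption, not repo code)
    return (1 - F.cosine_similarity(trained_embeds, frozen_embeds, dim=-1)).mean()

trained = torch.rand(4, 768, requires_grad=True)  # trainable encoder output
frozen = torch.rand(4, 768)                       # frozen reference output
loss = comparative_loss(trained, frozen)
loss.backward()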