WIP on clip vision encoder

Author: Jaret Burkett
Date:   2024-03-13 07:24:08 -06:00
parent  d87b49882c
commit  72de68d8aa
4 changed files with 164 additions and 73 deletions

@@ -740,14 +740,36 @@ class SDTrainer(BaseSDTrainProcess):
             embeds_to_use = conditional_embeds.clone().detach()
             # handle clip vision adapter by removing triggers from prompt and replacing with the class name
-            if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
+            if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
                 prompt_list = batch.get_caption_list()
+                class_name = ''
+                triggers = ['[trigger]', '[name]']
+                remove_tokens = []
+                if self.embed_config is not None:
+                    triggers.append(self.embed_config.trigger)
+                    for i in range(1, self.embed_config.tokens):
+                        remove_tokens.append(f"{self.embed_config.trigger}_{i}")
+                    if self.embed_config.trigger_class_name is not None:
+                        class_name = self.embed_config.trigger_class_name
+                if self.adapter is not None:
+                    triggers.append(self.adapter_config.trigger)
+                    for i in range(1, self.adapter_config.num_tokens):
+                        remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
+                    if self.adapter_config.trigger_class_name is not None:
+                        class_name = self.adapter_config.trigger_class_name
                 for idx, prompt in enumerate(prompt_list):
-                    prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
+                    for remove_token in remove_tokens:
+                        prompt = prompt.replace(remove_token, '')
+                    for trigger in triggers:
+                        prompt = prompt.replace(trigger, class_name)
                     prompt_list[idx] = prompt
                 embeds_to_use = self.sd.encode_prompt(
-                    prompt,
+                    prompt_list,
                     long_prompts=self.do_long_prompts).to(
                     self.device_torch,
                     dtype=dtype).detach()
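
This hunk generalizes trigger handling so it works for both the ClipVisionAdapter and plain embedding training: numbered continuation tokens (trigger_1, trigger_2, ...) are stripped first, then the base trigger is replaced with the class name, and the whole caption list is now encoded instead of only the last `prompt` left over from the loop. A minimal standalone sketch of that replacement logic; the trigger string, token count, and class name are made-up example values:

# Hypothetical standalone version of the prompt-cleanup logic above.
def replace_triggers(prompt: str, trigger: str, num_tokens: int, class_name: str) -> str:
    # Multi-token embeddings expand a trigger into trigger, trigger_1, ...,
    # trigger_{n-1}; strip the numbered tokens first, otherwise replacing the
    # base trigger would leave orphaned "_1" suffixes behind.
    for i in range(1, num_tokens):
        prompt = prompt.replace(f"{trigger}_{i}", '')
    for t in ['[trigger]', '[name]', trigger]:
        prompt = prompt.replace(t, class_name)
    return prompt

print(replace_triggers("photo of s0s1 s0s1_1 s0s1_2 smiling", "s0s1", 3, "person"))
# -> "photo of person   smiling" (leftover whitespace is typically harmless to the tokenizer)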
@@ -1030,7 +1052,8 @@ class SDTrainer(BaseSDTrainProcess):
                 if has_clip_image:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         clip_images.detach().to(self.device_torch, dtype=dtype),
-                        is_training=True
+                        is_training=True,
+                        has_been_preprocessed=True
                     )
                 else:
                     # just do a blank one
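
The added has_been_preprocessed=True presumably tells the adapter that the batch's clip_images are already resized and normalized for the CLIP vision encoder, so it should skip its own preprocessing. For reference, a minimal sketch of what such preprocessing usually looks like; the mean/std are the standard OpenAI CLIP constants, and whether the adapter uses exactly these values and this input size is an assumption:

import torch
import torchvision.transforms.functional as TF

# Standard OpenAI CLIP normalization constants (assumed to match the adapter).
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

def preprocess_for_clip(images: torch.Tensor) -> torch.Tensor:
    # images: (B, 3, H, W) float tensor with values in [0, 1]
    images = TF.resize(images, [224, 224], antialias=True)
    return TF.normalize(images, CLIP_MEAN, CLIP_STD)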
@@ -1039,7 +1062,9 @@ class SDTrainer(BaseSDTrainProcess):
                             (noisy_latents.shape[0], 3, 512, 512),
                             device=self.device_torch, dtype=dtype
                         ),
-                        is_training=True
+                        is_training=True,
+                        has_been_preprocessed=True,
+                        drop=True
                     )
                 # it will be injected into the tokenizer when called
                 self.adapter(conditional_clip_embeds)
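
Taken together, the two hunks route a zero image through the same embedding path when a batch has no clip image, with drop=True marking it as an unconditional sample. A hedged sketch of how a method with this signature might handle the three flags; the keyword names come from the diff, but the body, and the `encode`/`preprocess` callables standing in for the adapter's vision encoder and image processor, are illustrative assumptions rather than the repo's actual implementation:

import torch
from typing import Callable

def clip_image_embeds(images: torch.Tensor,
                      encode: Callable[[torch.Tensor], torch.Tensor],
                      preprocess: Callable[[torch.Tensor], torch.Tensor],
                      is_training: bool = False,
                      has_been_preprocessed: bool = False,
                      drop: bool = False) -> torch.Tensor:
    if not has_been_preprocessed:
        images = preprocess(images)  # resize + CLIP-normalize (assumed)
    if drop:
        # zero out the pixels so dropped samples share one consistent
        # "no image" embedding, analogous to prompt dropout for CFG
        images = torch.zeros_like(images)
    # only track gradients when the embeds feed a training step
    with torch.set_grad_enabled(is_training):
        return encode(images)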