mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
WIP on clip vision encoder
This commit is contained in:
@@ -740,14 +740,36 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
|
||||
embeds_to_use = conditional_embeds.clone().detach()
|
||||
# handle clip vision adapter by removing triggers from prompt and replacing with the class name
|
||||
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter):
|
||||
if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
|
||||
prompt_list = batch.get_caption_list()
|
||||
class_name = ''
|
||||
|
||||
triggers = ['[trigger]', '[name]']
|
||||
remove_tokens = []
|
||||
|
||||
if self.embed_config is not None:
|
||||
triggers.append(self.embed_config.trigger)
|
||||
for i in range(1, self.embed_config.tokens):
|
||||
remove_tokens.append(f"{self.embed_config.trigger}_{i}")
|
||||
if self.embed_config.trigger_class_name is not None:
|
||||
class_name = self.embed_config.trigger_class_name
|
||||
|
||||
if self.adapter is not None:
|
||||
triggers.append(self.adapter_config.trigger)
|
||||
for i in range(1, self.adapter_config.num_tokens):
|
||||
remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
|
||||
if self.adapter_config.trigger_class_name is not None:
|
||||
class_name = self.adapter_config.trigger_class_name
|
||||
|
||||
for idx, prompt in enumerate(prompt_list):
|
||||
prompt = self.adapter.inject_trigger_class_name_into_prompt(prompt)
|
||||
for remove_token in remove_tokens:
|
||||
prompt = prompt.replace(remove_token, '')
|
||||
for trigger in triggers:
|
||||
prompt = prompt.replace(trigger, class_name)
|
||||
prompt_list[idx] = prompt
|
||||
|
||||
embeds_to_use = self.sd.encode_prompt(
|
||||
prompt,
|
||||
prompt_list,
|
||||
long_prompts=self.do_long_prompts).to(
|
||||
self.device_torch,
|
||||
dtype=dtype).detach()
|
||||
@@ -1030,7 +1052,8 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
if has_clip_image:
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||
clip_images.detach().to(self.device_torch, dtype=dtype),
|
||||
is_training=True
|
||||
is_training=True,
|
||||
has_been_preprocessed=True
|
||||
)
|
||||
else:
|
||||
# just do a blank one
|
||||
@@ -1039,7 +1062,9 @@ class SDTrainer(BaseSDTrainProcess):
|
||||
(noisy_latents.shape[0], 3, 512, 512),
|
||||
device=self.device_torch, dtype=dtype
|
||||
),
|
||||
is_training=True
|
||||
is_training=True,
|
||||
has_been_preprocessed=True,
|
||||
drop=True
|
||||
)
|
||||
# it will be injected into the tokenizer when called
|
||||
self.adapter(conditional_clip_embeds)
|
||||
|
||||
Reference in New Issue
Block a user