mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Added some additional experimental things to the vision direct encoder
This commit is contained in:
@@ -407,7 +407,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
if 'vd_adapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
|
||||
if 'dvadapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
|
||||
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False)
|
||||
|
||||
if 'sv_adapter' in state_dict:
|
||||
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
|
||||
@@ -732,8 +732,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
def train(self, mode: bool = True):
|
||||
if self.config.train_image_encoder:
|
||||
self.vision_encoder.train(mode)
|
||||
else:
|
||||
super().train(mode)
|
||||
super().train(mode)
|
||||
|
||||
def trigger_pre_te(
|
||||
self,
|
||||
@@ -879,7 +878,10 @@ class CustomAdapter(torch.nn.Module):
|
||||
elif self.config.clip_layer == 'last_hidden_state':
|
||||
clip_image_embeds = clip_output.hidden_states[-1]
|
||||
else:
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
if hasattr(clip_output, 'image_embeds'):
|
||||
clip_image_embeds = clip_output.image_embeds
|
||||
elif hasattr(clip_output, 'pooler_output'):
|
||||
clip_image_embeds = clip_output.pooler_output
|
||||
# TODO should we always norm image embeds?
|
||||
# get norm embeddings
|
||||
l2_norm = torch.norm(clip_image_embeds, p=2)
|
||||
@@ -931,8 +933,12 @@ class CustomAdapter(torch.nn.Module):
|
||||
yield from attn_processor.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
if self.config.num_tokens:
|
||||
if self.vd_adapter.resampler is not None:
|
||||
yield from self.vd_adapter.resampler.parameters(recurse)
|
||||
if self.vd_adapter.pool is not None:
|
||||
yield from self.vd_adapter.pool.parameters(recurse)
|
||||
if self.vd_adapter.sparse_autoencoder is not None:
|
||||
yield from self.vd_adapter.sparse_autoencoder.parameters(recurse)
|
||||
elif self.config.type == 'te_augmenter':
|
||||
yield from self.te_augmenter.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
|
||||
Reference in New Issue
Block a user