Added some additional experimental features to the vision direct encoder

This commit is contained in:
Jaret Burkett
2024-10-10 19:42:26 +00:00
parent ab22674980
commit 3922981996
4 changed files with 101 additions and 23 deletions

View File

@@ -407,7 +407,7 @@ class CustomAdapter(torch.nn.Module):
if 'vd_adapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
if 'dvadapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=False)
if 'sv_adapter' in state_dict:
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
@@ -732,8 +732,7 @@ class CustomAdapter(torch.nn.Module):
def train(self, mode: bool = True):
if self.config.train_image_encoder:
self.vision_encoder.train(mode)
else:
super().train(mode)
super().train(mode)
def trigger_pre_te(
self,
@@ -879,7 +878,10 @@ class CustomAdapter(torch.nn.Module):
elif self.config.clip_layer == 'last_hidden_state':
clip_image_embeds = clip_output.hidden_states[-1]
else:
clip_image_embeds = clip_output.image_embeds
if hasattr(clip_output, 'image_embeds'):
clip_image_embeds = clip_output.image_embeds
elif hasattr(clip_output, 'pooler_output'):
clip_image_embeds = clip_output.pooler_output
# TODO should we always norm image embeds?
# get norm embeddings
l2_norm = torch.norm(clip_image_embeds, p=2)
@@ -931,8 +933,12 @@ class CustomAdapter(torch.nn.Module):
yield from attn_processor.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
if self.config.num_tokens:
if self.vd_adapter.resampler is not None:
yield from self.vd_adapter.resampler.parameters(recurse)
if self.vd_adapter.pool is not None:
yield from self.vd_adapter.pool.parameters(recurse)
if self.vd_adapter.sparse_autoencoder is not None:
yield from self.vd_adapter.sparse_autoencoder.parameters(recurse)
elif self.config.type == 'te_augmenter':
yield from self.te_augmenter.parameters(recurse)
if self.config.train_image_encoder: