Added support for training vision direct weight adapters

This commit is contained in:
Jaret Burkett
2024-09-05 10:11:44 -06:00
parent 5c8fcc8a4e
commit 3a1f464132
3 changed files with 78 additions and 16 deletions

View File

@@ -402,7 +402,7 @@ class CustomAdapter(torch.nn.Module):
if 'sv_adapter' in state_dict:
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
if 'vision_encoder' in state_dict and self.config.train_image_encoder:
if 'vision_encoder' in state_dict:
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
if 'fuse_module' in state_dict:
@@ -881,10 +881,14 @@ class CustomAdapter(torch.nn.Module):
for attn_processor in self.te_adapter.adapter_modules:
yield from attn_processor.parameters(recurse)
elif self.config.type == 'vision_direct':
for attn_processor in self.vd_adapter.adapter_modules:
yield from attn_processor.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
if self.config.train_scaler:
# only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
yield self.vd_adapter.block_scaler
else:
for attn_processor in self.vd_adapter.adapter_modules:
yield from attn_processor.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'te_augmenter':
yield from self.te_augmenter.parameters(recurse)
if self.config.train_image_encoder:
@@ -908,4 +912,10 @@ class CustomAdapter(torch.nn.Module):
additional[k] = v
additional['clip_layer'] = self.config.clip_layer
additional['image_encoder_arch'] = self.config.head_dim
return additional
return additional
def post_weight_update(self):
    """Hook invoked after an optimizer weight update.

    For the 'vision_direct' adapter type, delegates to the vision-direct
    adapter's own ``post_weight_update`` hook; all other adapter types
    require no post-update maintenance, so this is a no-op for them.
    """
    # only the vision_direct adapter carries state that must be refreshed
    # after its weights change
    if self.config.type == 'vision_direct':
        self.vd_adapter.post_weight_update()