mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-28 18:21:16 +00:00
Added support for training vision direct weight adapters
This commit is contained in:
@@ -402,7 +402,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
if 'sv_adapter' in state_dict:
|
||||
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
|
||||
|
||||
if 'vision_encoder' in state_dict and self.config.train_image_encoder:
|
||||
if 'vision_encoder' in state_dict:
|
||||
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
|
||||
|
||||
if 'fuse_module' in state_dict:
|
||||
@@ -881,10 +881,14 @@ class CustomAdapter(torch.nn.Module):
|
||||
for attn_processor in self.te_adapter.adapter_modules:
|
||||
yield from attn_processor.parameters(recurse)
|
||||
elif self.config.type == 'vision_direct':
|
||||
for attn_processor in self.vd_adapter.adapter_modules:
|
||||
yield from attn_processor.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
if self.config.train_scaler:
|
||||
# only yield the self.block_scaler = torch.nn.Parameter(torch.tensor([1.0] * num_modules)
|
||||
yield self.vd_adapter.block_scaler
|
||||
else:
|
||||
for attn_processor in self.vd_adapter.adapter_modules:
|
||||
yield from attn_processor.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
elif self.config.type == 'te_augmenter':
|
||||
yield from self.te_augmenter.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
@@ -908,4 +912,10 @@ class CustomAdapter(torch.nn.Module):
|
||||
additional[k] = v
|
||||
additional['clip_layer'] = self.config.clip_layer
|
||||
additional['image_encoder_arch'] = self.config.head_dim
|
||||
return additional
|
||||
return additional
|
||||
|
||||
def post_weight_update(self):
|
||||
# do any kind of updates after the weight update
|
||||
if self.config.type == 'vision_direct':
|
||||
self.vd_adapter.post_weight_update()
|
||||
pass
|
||||
Reference in New Issue
Block a user