Adjustments to the CLIP preprocessor. Allow merging in new weights for IP adapters so you can change the architecture while maintaining as much data as possible.

This commit is contained in:
Jaret Burkett
2024-01-06 11:56:53 -07:00
parent 645b27f97a
commit b767d29b3c
3 changed files with 99 additions and 36 deletions

View File

@@ -230,7 +230,7 @@ class IPAdapter(torch.nn.Module):
if self.config.image_encoder_arch == 'clip+':
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
preprocessor_input_size = self.image_encoder.config.image_size * 3
preprocessor_input_size = self.image_encoder.config.image_size * 4
# update the preprocessor so images come in at the right size
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
@@ -240,7 +240,6 @@ class IPAdapter(torch.nn.Module):
self.preprocessor = CLIPImagePreProcessor(
input_size=preprocessor_input_size,
clip_input_size=self.image_encoder.config.image_size,
downscale_factor=6
)
self.input_size = self.clip_image_processor.size['shortest_edge']
@@ -454,13 +453,68 @@ class IPAdapter(torch.nn.Module):
if self.preprocessor is not None:
yield from self.preprocessor.parameters(recurse)
def merge_in_weights(self, state_dict: Mapping[str, Any]):
    """Merge checkpoint weights into the current modules, tolerating shape changes.

    For both ``self.image_proj_model`` (checkpoint section ``"image_proj"``)
    and ``self.adapter_modules`` (checkpoint section ``"ip_adapter"``), every
    checkpoint tensor whose key also exists in the module is merged:

    * identical shape: the checkpoint tensor replaces the current one;
    * different shape, same rank: only the region both tensors share
      (``[:min(cur, new)]`` along each dimension) is copied; all other
      current values are left untouched.

    Keys present in the checkpoint but absent from the module are ignored.

    Args:
        state_dict: mapping that must contain ``"image_proj"`` and
            ``"ip_adapter"`` entries, each a mapping of parameter name to
            tensor.

    Raises:
        ValueError: if a tensor's number of dimensions differs from the
            current tensor stored under the same key.
    """
    targets = (
        (self.image_proj_model, "image_proj"),
        (self.adapter_modules, "ip_adapter"),
    )
    for module, section in targets:
        current_state_dict = module.state_dict()
        for key, value in state_dict[section].items():
            if key not in current_state_dict:
                # nothing in the current architecture to merge this into
                continue

            current = current_state_dict[key]
            current_shape = current.shape
            new_shape = value.shape

            if current_shape == new_shape:
                current_state_dict[key] = value
                continue

            if len(current_shape) != len(new_shape):
                raise ValueError(
                    f"cannot merge {key}: rank mismatch "
                    f"{list(current_shape)} vs {list(new_shape)}"
                )

            # Copy only the overlapping region so the merge works whether the
            # incoming tensor is smaller or larger along any dimension.
            overlap = tuple(
                slice(0, min(cur_dim, new_dim))
                for cur_dim, new_dim in zip(current_shape, new_shape)
            )
            current[overlap] = value[overlap]
            print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")

        module.load_state_dict(current_state_dict)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    """Load adapter weights, falling back to a partial merge on mismatch.

    The ``"image_proj"`` and ``"ip_adapter"`` sections are loaded into
    ``self.image_proj_model`` and ``self.adapter_modules``. If that load
    raises (for example because tensor sizes no longer match the current
    architecture), the error is printed and ``self.merge_in_weights`` is
    used to salvage whatever fits.

    The ``"image_encoder"`` section is loaded only when
    ``self.config.train_image_encoder`` is set, and the ``"preprocessor"``
    section only when ``self.preprocessor`` exists; both are skipped if the
    section is absent from ``state_dict``.

    Args:
        state_dict: mapping of section name to that sub-module's state dict.
        strict: accepted for interface compatibility; it is always
            overridden to ``False`` so missing/unexpected keys are tolerated.
    """
    # Loading is intentionally always non-strict, whatever the caller passed.
    strict = False

    try:
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
    except RuntimeError as e:
        # torch reports size mismatches as RuntimeError even when strict=False
        print(e)
        print("could not load ip adapter weights, trying to merge in weights")
        self.merge_in_weights(state_dict)

    if self.config.train_image_encoder and 'image_encoder' in state_dict:
        self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)

    if self.preprocessor is not None and 'preprocessor' in state_dict:
        self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
def enable_gradient_checkpointing(self):