mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 01:39:20 +00:00
Adjustments to the clip preprocessor. Allow merging in new weights for ip adapters so you can change the arcitecture while maintaining as much data as possible
This commit is contained in:
@@ -230,7 +230,7 @@ class IPAdapter(torch.nn.Module):
|
||||
if self.config.image_encoder_arch == 'clip+':
|
||||
# self.clip_image_processor.config
|
||||
# We do a 3x downscale of the image, so we need to adjust the input size
|
||||
preprocessor_input_size = self.image_encoder.config.image_size * 3
|
||||
preprocessor_input_size = self.image_encoder.config.image_size * 4
|
||||
|
||||
# update the preprocessor so images come in at the right size
|
||||
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
|
||||
@@ -240,7 +240,6 @@ class IPAdapter(torch.nn.Module):
|
||||
self.preprocessor = CLIPImagePreProcessor(
|
||||
input_size=preprocessor_input_size,
|
||||
clip_input_size=self.image_encoder.config.image_size,
|
||||
downscale_factor=6
|
||||
)
|
||||
|
||||
self.input_size = self.clip_image_processor.size['shortest_edge']
|
||||
@@ -454,13 +453,68 @@ class IPAdapter(torch.nn.Module):
|
||||
if self.preprocessor is not None:
|
||||
yield from self.preprocessor.parameters(recurse)
|
||||
|
||||
def merge_in_weights(self, state_dict: Mapping[str, Any]):
|
||||
# merge in img_proj weights
|
||||
current_img_proj_state_dict = self.image_proj_model.state_dict()
|
||||
for key, value in state_dict["image_proj"].items():
|
||||
if key in current_img_proj_state_dict:
|
||||
current_shape = current_img_proj_state_dict[key].shape
|
||||
new_shape = value.shape
|
||||
if current_shape != new_shape:
|
||||
# merge in what we can and leave the other values as they are
|
||||
if len(current_shape) == 1:
|
||||
current_img_proj_state_dict[key][:new_shape[0]] = value
|
||||
elif len(current_shape) == 2:
|
||||
current_img_proj_state_dict[key][:new_shape[0], :new_shape[1]] = value
|
||||
elif len(current_shape) == 3:
|
||||
current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
|
||||
elif len(current_shape) == 4:
|
||||
current_img_proj_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
|
||||
:new_shape[3]] = value
|
||||
else:
|
||||
raise ValueError(f"unknown shape: {current_shape}")
|
||||
print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
|
||||
else:
|
||||
current_img_proj_state_dict[key] = value
|
||||
self.image_proj_model.load_state_dict(current_img_proj_state_dict)
|
||||
|
||||
# merge in ip adapter weights
|
||||
current_ip_adapter_state_dict = self.adapter_modules.state_dict()
|
||||
for key, value in state_dict["ip_adapter"].items():
|
||||
if key in current_ip_adapter_state_dict:
|
||||
current_shape = current_ip_adapter_state_dict[key].shape
|
||||
new_shape = value.shape
|
||||
if current_shape != new_shape:
|
||||
# merge in what we can and leave the other values as they are
|
||||
if len(current_shape) == 1:
|
||||
current_ip_adapter_state_dict[key][:new_shape[0]] = value
|
||||
elif len(current_shape) == 2:
|
||||
current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1]] = value
|
||||
elif len(current_shape) == 3:
|
||||
current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2]] = value
|
||||
elif len(current_shape) == 4:
|
||||
current_ip_adapter_state_dict[key][:new_shape[0], :new_shape[1], :new_shape[2],
|
||||
:new_shape[3]] = value
|
||||
else:
|
||||
raise ValueError(f"unknown shape: {current_shape}")
|
||||
print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
|
||||
else:
|
||||
current_ip_adapter_state_dict[key] = value
|
||||
self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
|
||||
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
strict = False
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
||||
try:
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("could not load ip adapter weights, trying to merge in weights")
|
||||
self.merge_in_weights(state_dict)
|
||||
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
|
||||
if self.config.image_encoder_arch == 'clip+' and 'preprocessor' in state_dict:
|
||||
if self.preprocessor is not None and 'preprocessor' in state_dict:
|
||||
self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
|
||||
Reference in New Issue
Block a user