Added IP adapter training. It is not functioning correctly yet.

This commit is contained in:
Jaret Burkett
2023-09-24 02:39:43 -06:00
parent 19255cdc7c
commit 830e87cb87
9 changed files with 336 additions and 53 deletions

View File

@@ -13,8 +13,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
from safetensors.torch import save_file, load_file
from torch.nn import Parameter
from tqdm import tqdm
from torchvision.transforms import Resize
from torchvision.transforms import Resize, transforms
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
from toolkit import train_tools
@@ -123,7 +124,7 @@ class StableDiffusion:
# to hold network if there is one
self.network = None
self.adapter: Union['T2IAdapter', None] = None
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
@@ -302,12 +303,13 @@ class StableDiffusion:
Pipe = StableDiffusionPipeline
extra_args = {}
if self.adapter:
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
if self.adapter is not None:
if isinstance(self.adapter, T2IAdapter):
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
# TODO add clip skip
if self.is_xl:
@@ -358,11 +360,19 @@ class StableDiffusion:
gen_config = image_configs[i]
extra = {}
if gen_config.adapter_image_path is not None:
if self.adapter is not None and gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if isinstance(self.adapter, T2IAdapter):
# NOTE: the image is resized to twice the generation width/height; the reason for the doubling is unclear
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.PILToTensor(),
])
validation_image = transform(validation_image)
if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
@@ -382,6 +392,14 @@ class StableDiffusion:
unconditional_embeds,
)
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
# TODO: should the text encoder also be disabled here when it is disabled for the model, or only during training?
if self.is_xl:
# fix guidance rescale for sdxl
@@ -931,10 +949,18 @@ class StableDiffusion:
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
}
if self.adapter is not None:
if isinstance(self.adapter, IPAdapter):
requires_grad = self.adapter.adapter_modules.training
adapter_device = self.unet.device
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
else:
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
self.device_state['adapter'] = {
'training': self.adapter.training,
'device': self.adapter.device,
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
'device': adapter_device,
'requires_grad': requires_grad,
}
def restore_device_state(self):