mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 08:49:14 +00:00
Added IP adapter training. Not functioning correctly yet
This commit is contained in:
@@ -13,8 +13,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
|
||||
from safetensors.torch import save_file, load_file
|
||||
from torch.nn import Parameter
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize
|
||||
from torchvision.transforms import Resize, transforms
|
||||
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
from toolkit import train_tools
|
||||
@@ -123,7 +124,7 @@ class StableDiffusion:
|
||||
|
||||
# to hold network if there is one
|
||||
self.network = None
|
||||
self.adapter: Union['T2IAdapter', None] = None
|
||||
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
|
||||
self.is_xl = model_config.is_xl
|
||||
self.is_v2 = model_config.is_v2
|
||||
|
||||
@@ -302,12 +303,13 @@ class StableDiffusion:
|
||||
Pipe = StableDiffusionPipeline
|
||||
|
||||
extra_args = {}
|
||||
if self.adapter:
|
||||
if self.is_xl:
|
||||
Pipe = StableDiffusionXLAdapterPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionAdapterPipeline
|
||||
extra_args['adapter'] = self.adapter
|
||||
if self.adapter is not None:
|
||||
if isinstance(self.adapter, T2IAdapter):
|
||||
if self.is_xl:
|
||||
Pipe = StableDiffusionXLAdapterPipeline
|
||||
else:
|
||||
Pipe = StableDiffusionAdapterPipeline
|
||||
extra_args['adapter'] = self.adapter
|
||||
|
||||
# TODO add clip skip
|
||||
if self.is_xl:
|
||||
@@ -358,11 +360,19 @@ class StableDiffusion:
|
||||
gen_config = image_configs[i]
|
||||
|
||||
extra = {}
|
||||
if gen_config.adapter_image_path is not None:
|
||||
if self.adapter is not None and gen_config.adapter_image_path is not None:
|
||||
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = 1.0
|
||||
if isinstance(self.adapter, T2IAdapter):
|
||||
# NOTE: the adapter image is resized to 2x the generation width/height here; the reason for the doubling is not yet understood
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = 1.0
|
||||
if isinstance(self.adapter, IPAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.PILToTensor(),
|
||||
])
|
||||
validation_image = transform(validation_image)
|
||||
|
||||
if self.network is not None:
|
||||
self.network.multiplier = gen_config.network_multiplier
|
||||
@@ -382,6 +392,14 @@ class StableDiffusion:
|
||||
unconditional_embeds,
|
||||
)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
|
||||
# apply the image projection
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
|
||||
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
|
||||
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
|
||||
|
||||
|
||||
# TODO: should the text encoder also be disabled here when it is disabled for the model, or only during training?
|
||||
if self.is_xl:
|
||||
# fix guidance rescale for sdxl
|
||||
@@ -931,10 +949,18 @@ class StableDiffusion:
|
||||
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
|
||||
}
|
||||
if self.adapter is not None:
|
||||
if isinstance(self.adapter, IPAdapter):
|
||||
requires_grad = self.adapter.adapter_modules.training
|
||||
adapter_device = self.unet.device
|
||||
elif isinstance(self.adapter, T2IAdapter):
|
||||
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
|
||||
adapter_device = self.adapter.device
|
||||
else:
|
||||
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
|
||||
self.device_state['adapter'] = {
|
||||
'training': self.adapter.training,
|
||||
'device': self.adapter.device,
|
||||
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
|
||||
'device': adapter_device,
|
||||
'requires_grad': requires_grad,
|
||||
}
|
||||
|
||||
def restore_device_state(self):
|
||||
|
||||
Reference in New Issue
Block a user