Added a CLIP vision adapter trainer. Only works for SD 1.5 for now

Jaret Burkett
2023-12-24 13:26:04 -07:00
parent 0f8daa5612
commit 05ae95ca89
6 changed files with 586 additions and 20 deletions
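
For orientation before the diff: a minimal sketch of the idea behind a CLIP vision adapter, with hypothetical names (TinyClipVisionAdapter below is an illustration, not the ClipVisionAdapter this commit adds in toolkit/clip_vision_adapter.py). A frozen CLIP vision tower produces image embeds, and a small trainable projection maps them into SD 1.5's 768-dim text-embedding space as extra conditioning tokens; the trainer would optimize only that projection.

import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

class TinyClipVisionAdapter(nn.Module):
    # Hypothetical sketch: project frozen CLIP image embeds into the
    # SD 1.5 text-embedding space (768-dim) as a few extra tokens.
    def __init__(self, image_embed_dim=768, text_embed_dim=768, num_tokens=4):
        super().__init__()
        self.num_tokens = num_tokens
        # only this projection trains; the CLIP vision tower stays frozen
        self.proj = nn.Linear(image_embed_dim, text_embed_dim * num_tokens)

    def forward(self, image_embeds: torch.Tensor) -> torch.Tensor:
        # (batch, image_embed_dim) -> (batch, num_tokens, text_embed_dim)
        return self.proj(image_embeds).view(image_embeds.shape[0], self.num_tokens, -1)

# usage sketch
processor = CLIPImageProcessor()
vision = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
pixel_values = processor(images=Image.new("RGB", (512, 512)), return_tensors="pt").pixel_values
with torch.no_grad():
    image_embeds = vision(pixel_values).image_embeds   # (1, 768)
adapter_tokens = TinyClipVisionAdapter()(image_embeds)  # (1, 4, 768)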


@@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 from torchvision.transforms import Resize, transforms
+from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.ip_adapter import IPAdapter
 from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
     convert_vae_state_dict, load_vae
@@ -472,7 +473,7 @@ class StableDiffusion:
                 validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
                 extra['image'] = validation_image
                 extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
-            if isinstance(self.adapter, IPAdapter):
+            if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
                 transform = transforms.Compose([
                     transforms.ToTensor(),
                 ])
@@ -483,6 +484,12 @@ class StableDiffusion:
         torch.manual_seed(gen_config.seed)
         torch.cuda.manual_seed(gen_config.seed)
+        if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
+                and gen_config.adapter_image_path is not None:
+            # run through the adapter to saturate the embeds
+            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
+            self.adapter(conditional_clip_embeds)
         # encode the prompt ourselves so we can do fun stuff with embeddings
         conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
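
The ordering in this hunk is the key detail: the adapter is called on the image embeds before encode_prompt, which only makes sense if the call has a side effect, e.g. writing the image-derived embedding somewhere the text encoder will pick it up (such as a learned placeholder token). A hedged sketch of that pattern, with a hypothetical helper name:

# Hypothetical sketch of the "saturate, then encode" ordering shown above.
# Assumes adapter(embeds) installs image-derived vectors into a learned
# placeholder token so that encode_prompt() picks them up afterwards.
def generate_embeds(sd, gen_config, validation_image_tensor):
    if sd.adapter is not None and gen_config.adapter_image_path is not None:
        clip_embeds = sd.adapter.get_clip_image_embeds_from_tensors(validation_image_tensor)
        sd.adapter(clip_embeds)  # side effect: primes the placeholder token
    # encoding must happen *after* the adapter call, or the prompt would
    # see the stale placeholder embedding
    return sd.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)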
@@ -496,8 +503,8 @@ class StableDiffusion:
             unconditional_embeds,
         )
-        if self.adapter is not None and isinstance(self.adapter,
-                                                    IPAdapter) and gen_config.adapter_image_path is not None:
+        if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
+                and gen_config.adapter_image_path is not None:
             # apply the image projection
             conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
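
For contrast, the IPAdapter branch conditions after prompt encoding: the image embeds go through an image projection and are appended to the text embeddings rather than being baked into a token beforehand. A sketch under assumed names (image_proj and the concat dim are guesses, not this repo's API):

import torch

# Hypothetical contrast with the ClipVisionAdapter ordering: IP-Adapter-style
# conditioning projects image embeds to extra tokens and appends them to the
# already-encoded text embeddings. `image_proj` is an assumed attribute name.
def condition_ip_adapter(adapter, text_embeds: torch.Tensor,
                         clip_image_embeds: torch.Tensor) -> torch.Tensor:
    image_tokens = adapter.image_proj(clip_image_embeds)  # (b, n_tokens, dim)
    return torch.cat([text_embeds, image_tokens], dim=1)  # concat on sequence dim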
@@ -1445,6 +1452,9 @@ class StableDiffusion:
         elif isinstance(self.adapter, T2IAdapter):
             requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
             adapter_device = self.adapter.device
+        elif isinstance(self.adapter, ClipVisionAdapter):
+            requires_grad = self.adapter.embedder.training
+            adapter_device = self.adapter.device
         else:
             raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
         self.device_state['adapter'] = {
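
The device_state bookkeeping above records enough about each adapter type (grad/training state plus device) to offload it and restore it later; for ClipVisionAdapter the trainable part is evidently its embedder. A generic sketch of the save/restore pattern, with hypothetical helper names:

import torch.nn as nn

# Hypothetical generic form of the bookkeeping above: capture just enough
# state to offload a module and faithfully restore it afterwards.
def snapshot_state(module: nn.Module) -> dict:
    return {
        "requires_grad": any(p.requires_grad for p in module.parameters()),
        "device": next(module.parameters()).device,
    }

def restore_state(module: nn.Module, state: dict) -> None:
    module.to(state["device"])
    module.requires_grad_(state["requires_grad"])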