mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-27 01:39:20 +00:00
Added a CLIP vision adapter trainer. Only works for SD 1.5 for now.
This commit is contained in:
@@ -18,6 +18,7 @@ from torch.utils.checkpoint import checkpoint
|
||||
from tqdm import tqdm
|
||||
from torchvision.transforms import Resize, transforms
|
||||
|
||||
from toolkit.clip_vision_adapter import ClipVisionAdapter
|
||||
from toolkit.ip_adapter import IPAdapter
|
||||
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
|
||||
convert_vae_state_dict, load_vae
|
||||
@@ -472,7 +473,7 @@ class StableDiffusion:
|
||||
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
|
||||
extra['image'] = validation_image
|
||||
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
|
||||
if isinstance(self.adapter, IPAdapter):
|
||||
if isinstance(self.adapter, IPAdapter) or isinstance(self.adapter, ClipVisionAdapter):
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
@@ -483,6 +484,12 @@ class StableDiffusion:
|
||||
torch.manual_seed(gen_config.seed)
|
||||
torch.cuda.manual_seed(gen_config.seed)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter) \
|
||||
and gen_config.adapter_image_path is not None:
|
||||
# run through the adapter to saturate the embeds
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||
self.adapter(conditional_clip_embeds)
|
||||
|
||||
# encode the prompt ourselves so we can do fun stuff with embeddings
|
||||
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
|
||||
|
||||
@@ -496,8 +503,8 @@ class StableDiffusion:
|
||||
unconditional_embeds,
|
||||
)
|
||||
|
||||
if self.adapter is not None and isinstance(self.adapter,
|
||||
IPAdapter) and gen_config.adapter_image_path is not None:
|
||||
if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
|
||||
and gen_config.adapter_image_path is not None:
|
||||
|
||||
# apply the image projection
|
||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
|
||||
@@ -1445,6 +1452,9 @@ class StableDiffusion:
|
||||
elif isinstance(self.adapter, T2IAdapter):
|
||||
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
|
||||
adapter_device = self.adapter.device
|
||||
elif isinstance(self.adapter, ClipVisionAdapter):
|
||||
requires_grad = self.adapter.embedder.training
|
||||
adapter_device = self.adapter.device
|
||||
else:
|
||||
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
|
||||
self.device_state['adapter'] = {
|
||||
|
||||
Reference in New Issue
Block a user