Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-02 20:21:22 +00:00)
Switched ip adapter dataloader to clip_image paths so the control paths can be used for training assistant adapters while training ip adapters
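In practice this means an IP adapter dataset now points its reference images at clip_image_path instead of control_path, leaving control_path free for an assistant adapter (depth maps and the like). A minimal sketch of an updated dataset config, assuming the existing folder_path field and made-up directory names (only clip_image_path is the kwarg actually added in this commit):

# Hedged sketch only: folder_path and all example paths are illustrative assumptions.
from toolkit.config_modules import DatasetConfig

dataset = DatasetConfig(
    folder_path='/data/person/images',          # training images (assumed field)
    clip_image_path='/data/person/reference',   # IP adapter reference images (was control_path)
    control_path='/data/person/depth',          # now available for an assistant adapter
)

Reference images are matched to training images by base file name (trying .jpg/.jpeg/.png/.webp) and resized to 512x512 by the new dataloader mixin further down in this diff.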
@@ -567,6 +567,10 @@ class SDTrainer(BaseSDTrainProcess):
                network_weight_list = network_weight_list + network_weight_list

        has_adapter_img = batch.control_tensor is not None
        has_clip_image = batch.clip_image_tensor is not None

        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
            raise ValueError("IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config ")

        match_adapter_assist = False
@@ -604,10 +608,14 @@ class SDTrainer(BaseSDTrainProcess):
                    adapter_images = adapter_images[:, :in_channels, :, :]
                else:
                    raise NotImplementedError("Adapter images now must be loaded with dataloader")
                # not 100% sure what this does. But they do it here
                # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
                # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
                # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)

        clip_images = None
        if has_clip_image:
            with self.timer('get_clip_images'):
                # todo move this to data loader
                if batch.clip_image_tensor is not None:
                    clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()

        mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
        if batch.mask_tensor is not None:
@@ -697,6 +705,10 @@ class SDTrainer(BaseSDTrainProcess):
                adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
            else:
                adapter_images_list = [None for _ in range(batch_size)]
            if clip_images is not None:
                clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
            else:
                clip_images_list = [None for _ in range(batch_size)]
            mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
            if prompts_2 is None:
                prompt_2_list = [None for _ in range(batch_size)]
@@ -710,19 +722,21 @@ class SDTrainer(BaseSDTrainProcess):
            conditioned_prompts_list = [prompts_1]
            imgs_list = [imgs]
            adapter_images_list = [adapter_images]
            clip_images_list = [clip_images]
            mask_multiplier_list = [mask_multiplier]
            if prompts_2 is None:
                prompt_2_list = [None]
            else:
                prompt_2_list = [prompts_2]

-       for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, mask_multiplier, prompt_2 in zip(
+       for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
                noisy_latents_list,
                noise_list,
                timesteps_list,
                conditioned_prompts_list,
                imgs_list,
                adapter_images_list,
                clip_images_list,
                mask_multiplier_list,
                prompt_2_list
        ):
@@ -766,7 +780,7 @@ class SDTrainer(BaseSDTrainProcess):
            if has_adapter_img and (
                    (self.adapter and isinstance(self.adapter, T2IAdapter)) or self.assistant_adapter):
                with torch.set_grad_enabled(self.adapter is not None):
-                   adapter = self.adapter if self.adapter else self.assistant_adapter
+                   adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
                    adapter_multiplier = get_adapter_multiplier()
                    with self.timer('encode_adapter'):
                        down_block_additional_residuals = adapter(adapter_images)
@@ -787,20 +801,20 @@ class SDTrainer(BaseSDTrainProcess):

            if self.adapter and isinstance(self.adapter, IPAdapter):
                with self.timer('encode_adapter_embeds'):
-                   if has_adapter_img:
+                   if has_clip_image:
                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                           adapter_images.detach().to(self.device_torch, dtype=dtype),
+                           clip_images.detach().to(self.device_torch, dtype=dtype),
                            is_training=True
                        )
                    elif is_reg:
                        # we will zero it out in the img embedder
-                       adapter_img = torch.zeros(
+                       clip_images = torch.zeros(
                            (noisy_latents.shape[0], 3, 512, 512),
                            device=self.device_torch, dtype=dtype
                        ).detach()
                        # drop will zero it out
                        conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
-                           adapter_img,
+                           clip_images,
                            drop=True,
                            is_training=True
                        )
@@ -421,6 +421,9 @@ class DatasetConfig:
        self.caption_type = self.caption_ext
        self.guidance_type: GuidanceType = kwargs.get('guidance_type', 'targeted')

        # ip adapter / reference dataset
        self.clip_image_path: str = kwargs.get('clip_image_path', None)  # depth maps, etc


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
    """
@@ -8,7 +8,7 @@ from PIL.ImageOps import exif_transpose
from toolkit import image_utils
from toolkit.dataloader_mixins import CaptionProcessingDTOMixin, ImageProcessingDTOMixin, LatentCachingFileItemDTOMixin, \
    ControlFileItemDTOMixin, ArgBreakMixin, PoiFileItemDTOMixin, MaskFileItemDTOMixin, AugmentationFileItemDTOMixin, \
-   UnconditionalFileItemDTOMixin
+   UnconditionalFileItemDTOMixin, ClipImageFileItemDTOMixin

if TYPE_CHECKING:
    from toolkit.config_modules import DatasetConfig
@@ -28,6 +28,7 @@ class FileItemDTO(
    CaptionProcessingDTOMixin,
    ImageProcessingDTOMixin,
    ControlFileItemDTOMixin,
    ClipImageFileItemDTOMixin,
    MaskFileItemDTOMixin,
    AugmentationFileItemDTOMixin,
    UnconditionalFileItemDTOMixin,
@@ -71,6 +72,7 @@ class FileItemDTO(
        self.tensor = None
        self.cleanup_latent()
        self.cleanup_control()
        self.cleanup_clip_image()
        self.cleanup_mask()
        self.cleanup_unconditional()
@@ -83,6 +85,7 @@ class DataLoaderBatchDTO:
        self.tensor: Union[torch.Tensor, None] = None
        self.latents: Union[torch.Tensor, None] = None
        self.control_tensor: Union[torch.Tensor, None] = None
        self.clip_image_tensor: Union[torch.Tensor, None] = None
        self.mask_tensor: Union[torch.Tensor, None] = None
        self.unaugmented_tensor: Union[torch.Tensor, None] = None
        self.unconditional_tensor: Union[torch.Tensor, None] = None
@@ -113,6 +116,21 @@ class DataLoaderBatchDTO:
                    control_tensors.append(x.control_tensor)
            self.control_tensor = torch.cat([x.unsqueeze(0) for x in control_tensors])

        if any([x.clip_image_tensor is not None for x in self.file_items]):
            # find one to use as a base
            base_clip_image_tensor = None
            for x in self.file_items:
                if x.clip_image_tensor is not None:
                    base_clip_image_tensor = x.clip_image_tensor
                    break
            clip_image_tensors = []
            for x in self.file_items:
                if x.clip_image_tensor is None:
                    clip_image_tensors.append(torch.zeros_like(base_clip_image_tensor))
                else:
                    clip_image_tensors.append(x.clip_image_tensor)
            self.clip_image_tensor = torch.cat([x.unsqueeze(0) for x in clip_image_tensors])

        if any([x.mask_tensor is not None for x in self.file_items]):
            # find one to use as a base
            base_mask_tensor = None
@@ -350,6 +350,8 @@ class ImageProcessingDTOMixin:
            self.get_latent()
            if self.has_control_image:
                self.load_control_image()
            if self.has_clip_image:
                self.load_clip_image()
            if self.has_mask_image:
                self.load_mask_image()
            if self.has_unconditional:
@@ -443,6 +445,8 @@ class ImageProcessingDTOMixin:
        if not only_load_latents:
            if self.has_control_image:
                self.load_control_image()
            if self.has_clip_image:
                self.load_clip_image()
            if self.has_mask_image:
                self.load_mask_image()
            if self.has_unconditional:
@@ -523,6 +527,46 @@ class ControlFileItemDTOMixin:
        self.control_tensor = None


class ClipImageFileItemDTOMixin:
    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
            super().__init__(*args, **kwargs)
        self.has_clip_image = False
        self.clip_image_path: Union[str, None] = None
        self.clip_image_tensor: Union[torch.Tensor, None] = None
        dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
        if dataset_config.clip_image_path is not None:
            # find the control image path
            clip_image_path = dataset_config.clip_image_path
            # we are using control images
            img_path = kwargs.get('path', None)
            img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
            file_name_no_ext = os.path.splitext(os.path.basename(img_path))[0]
            for ext in img_ext_list:
                if os.path.exists(os.path.join(clip_image_path, file_name_no_ext + ext)):
                    self.clip_image_path = os.path.join(clip_image_path, file_name_no_ext + ext)
                    self.has_clip_image = True
                    break

    def load_clip_image(self: 'FileItemDTO'):
        img = Image.open(self.clip_image_path).convert('RGB')
        try:
            img = exif_transpose(img)
        except Exception as e:
            print(f"Error: {e}")
            print(f"Error loading image: {self.clip_image_path}")

        # we just scale them to 512x512:
        img = img.resize((512, 512), Image.BICUBIC)

        self.clip_image_tensor = transforms.ToTensor()(img)

    def cleanup_clip_image(self: 'FileItemDTO'):
        self.clip_image_tensor = None


class AugmentationFileItemDTOMixin:
    def __init__(self: 'FileItemDTO', *args, **kwargs):
        if hasattr(super(), '__init__'):
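To make the pairing rule above concrete, a small hypothetical example of how a clip image is resolved for a training image; the paths are made up, only the base-name plus extension matching mirrors ClipImageFileItemDTOMixin:

import os

clip_image_dir = '/data/person/reference'           # dataset_config.clip_image_path (example)
train_img = '/data/person/images/portrait_01.jpg'   # the item's main image (example)

base = os.path.splitext(os.path.basename(train_img))[0]
clip_image = None
for ext in ['.jpg', '.jpeg', '.png', '.webp']:
    candidate = os.path.join(clip_image_dir, base + ext)
    if os.path.exists(candidate):
        clip_image = candidate  # e.g. /data/person/reference/portrait_01.png
        break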
@@ -153,7 +153,10 @@ class IPAdapter(torch.nn.Module):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        try:
            self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
        except EnvironmentError:
            self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path,
                                                                           ignore_mismatched_sizes=True)