mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-02-20 12:23:57 +00:00
Fixed ip adapter training. Works now
@@ -549,11 +549,13 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.stop('preprocess_batch')

is_reg = False
with torch.no_grad():
    loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
    for idx, file_item in enumerate(batch.file_items):
        if file_item.is_reg:
            loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
            is_reg = True

adapter_images = None
sigmas = None
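For orientation, here is a minimal standalone sketch of how a per-sample multiplier like the one built in this hunk is typically applied to the diffusion loss. It is not the trainer's actual loss code; the function name, the MSE objective, and the 0.5 default for reg_weight are illustrative assumptions.

import torch
import torch.nn.functional as F

def weighted_noise_loss(noise_pred, noise_target, is_reg_flags, reg_weight=0.5):
    # one multiplier per sample, shaped (B, 1, 1, 1) so it broadcasts over C/H/W
    loss_multiplier = torch.ones(
        (noise_pred.shape[0], 1, 1, 1),
        device=noise_pred.device, dtype=noise_pred.dtype
    )
    for idx, is_reg in enumerate(is_reg_flags):
        if is_reg:
            # regularization samples contribute less to the gradient
            loss_multiplier[idx] = loss_multiplier[idx] * reg_weight
    per_element = F.mse_loss(noise_pred, noise_target, reduction='none')
    return (per_element * loss_multiplier).mean()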
@@ -764,11 +766,27 @@ class SDTrainer(BaseSDTrainProcess):
    batch=batch,
)

if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
    with self.timer('encode_adapter'):
if self.adapter and isinstance(self.adapter, IPAdapter):
    with self.timer('encode_adapter_embeds'):
        with torch.no_grad():
            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
        conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
            if has_adapter_img:
                conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                    adapter_images.detach().to(self.device_torch, dtype=dtype))
            elif is_reg:
                # we will zero it out in the img embedder
                adapter_img = torch.zeros(
                    (noisy_latents.shape[0], 3, 512, 512),
                    device=self.device_torch, dtype=dtype
                )
                # drop will zero it out
                conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                    adapter_img, drop=True
                )
            else:
                raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")

    with self.timer('encode_adapter'):
        conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())

self.before_unet_predict()
# do a prior pred if we have an unconditional image, we will swap out the guidance later
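Restated as a compressed, hedged sketch, the conditioning flow this hunk settles on looks roughly like the function below. The adapter, text_embeds, and images arguments stand in for the trainer's own objects; the 512x512 zero image comes straight from the hunk, while everything else about the signature is illustrative.

import torch

def build_ip_conditioning(adapter, text_embeds, images, is_reg, batch_size, device, dtype):
    with torch.no_grad():
        if images is not None:
            # an adapter image came from the dataloader
            clip_embeds = adapter.get_clip_image_embeds_from_tensors(
                images.detach().to(device, dtype=dtype))
        elif is_reg:
            # reg batches get a zeroed image; drop=True zeroes it inside the embedder
            zero_img = torch.zeros((batch_size, 3, 512, 512), device=device, dtype=dtype)
            clip_embeds = adapter.get_clip_image_embeds_from_tensors(zero_img, drop=True)
        else:
            raise ValueError("adapter images must come from the dataloader or a reg image")
    # project the detached CLIP embeds and fuse them with the (detached) text embeds
    return adapter(text_embeds.detach(), clip_embeds.detach())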
@@ -880,13 +880,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
        self.device,
        dtype=dtype
    )
    self.adapter.load_state_dict(loaded_state_dict)
else:
    # ip adapter
    loaded_state_dict = load_ip_adapter_model(
        latest_save_path,
        self.device,
        dtype=dtype
    )
    self.adapter.load_state_dict(loaded_state_dict)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
    self.load_training_state_from_metadata(latest_save_path)
# set trainable params
Submodule repositories/ipadapter updated: d8ab37c421...f71c943b7e
@@ -375,6 +375,8 @@ class DatasetConfig:
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None)  # depth maps, etc
# instead of cropping to match the image, it will serve the full size control image (i.e. clip images for ip adapters)
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
self.alpha_mask: bool = kwargs.get('alpha_mask', False)  # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path',
                                 None)  # focus mask (black and white. White has higher loss than black)
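A hedged example of how these dataset options might be wired up. Only the keys shown in this hunk (flip_y, augments, control_path, full_size_control_images, alpha_mask, mask_path) are taken from the code; the folder_path key, the paths, and the dict shape are illustrative, not a verified ai-toolkit config.

dataset_kwargs = {
    'folder_path': '/path/to/train/images',  # assumed key for the image folder
    'control_path': '/path/to/clip_refs',    # paired control images, e.g. CLIP reference images for an IP adapter
    'full_size_control_images': True,        # serve the control image uncropped instead of matching the bucket crop
    'flip_y': False,
    'augments': [],
    'alpha_mask': False,
    'mask_path': None,
}
# config = DatasetConfig(**dataset_kwargs)  # assuming DatasetConfig accepts these as keyword args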
@@ -457,9 +457,11 @@ class ControlFileItemDTOMixin:
self.control_path: Union[str, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False
if dataset_config.control_path is not None:
    # find the control image path
    control_path = dataset_config.control_path
    self.full_size_control_images = dataset_config.full_size_control_images
    # we are using control images
    img_path = kwargs.get('path', None)
    img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
@@ -477,36 +479,38 @@ class ControlFileItemDTOMixin:
except Exception as e:
    print(f"Error: {e}")
    print(f"Error loading image: {self.control_path}")
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
    # throw error, they should match
    raise ValueError(
        f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
    # throw error, they should match
    raise ValueError(
        f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

if self.flip_x:
    # do a flip
    img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
    # do a flip
    img.transpose(Image.FLIP_TOP_BOTTOM)
if not self.full_size_control_images:
    w, h = img.size
    if w > h and self.scale_to_width < self.scale_to_height:
        # throw error, they should match
        raise ValueError(
            f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
    elif h > w and self.scale_to_height < self.scale_to_width:
        # throw error, they should match
        raise ValueError(
            f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")

if self.dataset_config.buckets:
    # scale and crop based on file item
    img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
    # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
    # crop
    img = img.crop((
        self.crop_x,
        self.crop_y,
        self.crop_x + self.crop_width,
        self.crop_y + self.crop_height
    ))
else:
    raise Exception("Control images not supported for non-bucket datasets")
if self.flip_x:
    # do a flip
    img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
    # do a flip
    img.transpose(Image.FLIP_TOP_BOTTOM)

if self.dataset_config.buckets:
    # scale and crop based on file item
    img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
    # img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
    # crop
    img = img.crop((
        self.crop_x,
        self.crop_y,
        self.crop_x + self.crop_width,
        self.crop_y + self.crop_height
    ))
else:
    raise Exception("Control images not supported for non-bucket datasets")

self.control_tensor = transforms.ToTensor()(img)
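The net effect of the new guard, restated as a hedged sketch (a standalone function with placeholder arguments, not the actual ControlFileItemDTOMixin code): full-size control images skip the aspect check and the bucket resize/crop and are handed over untouched, because a downstream image encoder (e.g. CLIP for an IP adapter) does its own resizing.

from PIL import Image
from torchvision import transforms

def load_control_tensor(path, full_size, scale_to, crop_box):
    # scale_to is (width, height); crop_box is (left, top, right, bottom)
    img = Image.open(path).convert('RGB')
    if not full_size:
        # match the training image: resize to the bucket size, then crop
        img = img.resize(scale_to, Image.BICUBIC).crop(crop_box)
    return transforms.ToTensor()(img)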
@@ -6,6 +6,7 @@ from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype

sys.path.append(REPOS_ROOT)
@@ -21,6 +22,16 @@ import weakref
if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

from transformers import (
    CLIPImageProcessor,
    CLIPVisionModelWithProjection,
)

from diffusers.models.attention_processor import (
    IPAdapterAttnProcessor,
    IPAdapterAttnProcessor2_0,
)


# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
@@ -89,6 +100,16 @@ class IPAdapter(torch.nn.Module):
    self.unet_ref: weakref.ref = weakref.ref(sd.unet)
    self.image_proj_model = image_proj_model
    self.adapter_modules = adapter_modules
    # load the weights if we have some
    if self.config.name_or_path:
        loaded_state_dict = load_ip_adapter_model(
            self.config.name_or_path,
            device='cpu',
            dtype=sd.torch_dtype
        )
        self.load_state_dict(loaded_state_dict)

    self.set_scale(1.0)

def to(self, *args, **kwargs):
    super().to(*args, **kwargs)
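With this addition, pointing the adapter config at existing weights loads them on CPU, casts them to the model dtype, and then set_scale(1.0) is applied. A hedged example of the relevant config fields follows; only name_or_path (here) and train (used in the resume path earlier in this commit) appear in the diff, the rest is assumed.

adapter_config = {
    'type': 'ip_adapter',                               # assumed discriminator, not shown in this hunk
    'name_or_path': '/path/to/ip_adapter_weights.bin',  # illustrative path to pretrained IP adapter weights
    'train': True,
}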
@@ -102,6 +123,9 @@ class IPAdapter(torch.nn.Module):
    ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
    ip_layers.load_state_dict(state_dict["ip_adapter"])

# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
#     self.load_ip_adapter(state_dict)

def state_dict(self) -> OrderedDict:
    state_dict = OrderedDict()
    state_dict["image_proj"] = self.image_proj_model.state_dict()
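state_dict() and load_state_dict now follow the common IP-Adapter checkpoint layout with two top-level groups, "image_proj" and "ip_adapter". A minimal round-trip sketch; the file name and the use of torch.save/torch.load are illustrative, and the toolkit's own saving helpers may differ.

import torch

checkpoint = adapter.state_dict()         # {"image_proj": {...}, "ip_adapter": {...}}
torch.save(checkpoint, 'ip_adapter.bin')  # hypothetical output path

loaded = torch.load('ip_adapter.bin', map_location='cpu')
adapter.load_state_dict(loaded)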
@@ -109,7 +133,7 @@ class IPAdapter(torch.nn.Module):
    return state_dict

def set_scale(self, scale):
    for attn_processor in self.pipe.unet.attn_processors.values():
    for attn_processor in self.sd_ref().unet.attn_processors.values():
        if isinstance(attn_processor, IPAttnProcessor):
            attn_processor.scale = scale
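set_scale now walks the live UNet's attention processors through the sd_ref weak reference instead of a pipeline attribute. A small usage sketch, assuming an already constructed adapter instance:

adapter.set_scale(0.0)   # image prompt contributes nothing; text-only behaviour
adapter.set_scale(1.0)   # full image-prompt conditioning (the default set in __init__)
adapter.set_scale(0.6)   # hypothetical partial blend when the image prompt overpowers the text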
@@ -131,9 +155,21 @@ class IPAdapter(torch.nn.Module):
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
    tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
    raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
        tensors_0_1.min(), tensors_0_1.max()
    ))

clip_image = self.clip_image_processor(
    images=tensors_0_1,
    return_tensors="pt",
    do_resize=True,
    do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
    clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
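A standalone sketch of the preprocessing change: the incoming training tensors are already scaled to [0, 1], so the CLIP processor should resize them but must not rescale them again, hence do_resize=True and do_rescale=False. The processor and encoder below are built from library defaults; the checkpoint name is an assumption for illustration, not necessarily what the toolkit loads.

import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

processor = CLIPImageProcessor()  # default CLIP resize/normalize settings
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    'openai/clip-vit-large-patch14'  # assumed checkpoint
)

tensors_0_1 = torch.rand(2, 3, 512, 512)  # stand-in for dataloader output in [0, 1]
clip_image = processor(
    images=tensors_0_1,
    return_tensors='pt',
    do_resize=True,    # bring arbitrary training resolutions down to CLIP's input size
    do_rescale=False,  # values are already 0-1; rescaling would divide by 255 a second time
).pixel_values
hidden = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]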
@@ -155,3 +191,4 @@ class IPAdapter(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
    self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
    self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
@@ -474,9 +474,7 @@ class StableDiffusion:
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter):
    transform = transforms.Compose([
        transforms.Resize(gen_config.width,
                          interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.PILToTensor(),
        transforms.ToTensor(),
    ])
    validation_image = transform(validation_image)
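The swap from PILToTensor to ToTensor matters because ToTensor converts to float and scales pixel values to [0, 1], which is exactly the range get_clip_image_embeds_from_tensors now checks for and passes to the processor with do_rescale=False. A quick illustration with a dummy image:

from PIL import Image
from torchvision import transforms

img = Image.new('RGB', (768, 512), color=(255, 128, 0))  # placeholder validation image
as_uint8 = transforms.PILToTensor()(img)  # uint8, 0-255: would trip the 0-1 range check
as_float = transforms.ToTensor()(img)     # float32, 0.0-1.0: matches the do_rescale=False path
print(as_uint8.max().item(), as_float.max().item())  # 255 vs 1.0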
@@ -500,6 +498,7 @@ class StableDiffusion:
if self.adapter is not None and isinstance(self.adapter,
                                           IPAdapter) and gen_config.adapter_image_path is not None:

    # apply the image projection
    conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
    unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,