Fixed IP adapter training. Works now

This commit is contained in:
Jaret Burkett
2023-12-17 08:22:59 -07:00
parent 13d32423f6
commit b653906715
7 changed files with 102 additions and 40 deletions

View File

@@ -549,11 +549,13 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.stop('preprocess_batch')
is_reg = False
with torch.no_grad():
loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
for idx, file_item in enumerate(batch.file_items):
if file_item.is_reg:
loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
is_reg = True
adapter_images = None
sigmas = None
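For reference, a standalone sketch (not from the diff) of the per-item loss weighting above, with illustrative values; the final reduction line is a placeholder rather than the trainer's actual loss code.

import torch

batch_size, device, dtype, reg_weight = 4, 'cpu', torch.float32, 0.5  # illustrative values
is_reg_flags = [False, True, False, False]  # stand-in for batch.file_items[i].is_reg

# one multiplier per batch item, broadcastable over the latent dimensions
loss_multiplier = torch.ones((batch_size, 1, 1, 1), device=device, dtype=dtype)
for idx, item_is_reg in enumerate(is_reg_flags):
    if item_is_reg:
        loss_multiplier[idx] = loss_multiplier[idx] * reg_weight

per_pixel_loss = torch.rand(batch_size, 4, 64, 64)  # placeholder per-pixel loss
loss = (per_pixel_loss * loss_multiplier).mean()  # reg items contribute reg_weight times as much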
@@ -764,11 +766,27 @@ class SDTrainer(BaseSDTrainProcess):
batch=batch,
)
if has_adapter_img and self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter'):
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
if has_adapter_img:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
adapter_images.detach().to(self.device_torch, dtype=dtype))
elif is_reg:
# we will zero it out in the img embedder
adapter_img = torch.zeros(
(noisy_latents.shape[0], 3, 512, 512),
device=self.device_torch, dtype=dtype
)
# drop will zero it out
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
adapter_img, drop=True
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
with self.timer('encode_adapter'):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())
self.before_unet_predict()
# do a prior pred if we have an unconditional image; we will swap out the guidance later
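A condensed, standalone version of the branching above (not from the diff); the adapter is assumed to be an IPAdapter instance, and the function name and signature are illustrative.

import torch

def build_adapter_embeds(adapter, conditional_embeds, adapter_images, is_reg,
                         batch_size, device, dtype):
    with torch.no_grad():
        if adapter_images is not None:
            # normal case: adapter images came from the dataloader
            clip_embeds = adapter.get_clip_image_embeds_from_tensors(
                adapter_images.detach().to(device, dtype=dtype))
        elif is_reg:
            # reg images have no paired adapter image: feed a blank image and
            # let drop=True zero it out inside the image embedder
            blank = torch.zeros((batch_size, 3, 512, 512), device=device, dtype=dtype)
            clip_embeds = adapter.get_clip_image_embeds_from_tensors(blank, drop=True)
        else:
            raise ValueError("adapter images must come from the dataloader, or the item must be a reg image")
    # project the CLIP image embeds and merge them into the text conditioning
    return adapter(conditional_embeds.detach(), clip_embeds.detach())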

View File

@@ -880,13 +880,15 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
else:
# ip adapter
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
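Paraphrased outside the diff, the resume path now looks like the sketch below; the non-IP branch's loader is not visible in this hunk, so it is passed in as a placeholder callable.

from toolkit.saving import load_ip_adapter_model  # import shown elsewhere in this commit

def resume_adapter_weights(adapter, adapter_config, latest_save_path, device, dtype,
                           is_ip_adapter=True, load_other_state_dict=None, load_training_state=None):
    if not is_ip_adapter:
        # placeholder for whatever loader the first branch already uses
        loaded_state_dict = load_other_state_dict(latest_save_path, device, dtype=dtype)
    else:
        # ip adapter
        loaded_state_dict = load_ip_adapter_model(latest_save_path, device, dtype=dtype)
    # the state dict is now applied once, after either branch
    adapter.load_state_dict(loaded_state_dict)
    if adapter_config.train and load_training_state is not None:
        load_training_state(latest_save_path)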

View File

@@ -375,6 +375,8 @@ class DatasetConfig:
self.flip_y: bool = kwargs.get('flip_y', False)
self.augments: List[str] = kwargs.get('augments', [])
self.control_path: str = kwargs.get('control_path', None) # depth maps, etc
# instead of cropping to match the image, serve the full-size control image (e.g. CLIP images for IP adapters)
self.full_size_control_images: bool = kwargs.get('full_size_control_images', False)
self.alpha_mask: bool = kwargs.get('alpha_mask', False) # if true, will use alpha channel as mask
self.mask_path: str = kwargs.get('mask_path',
None) # focus mask (black and white. White has higher loss than black)
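A hypothetical dataset entry using the new flag (not from the diff); only fields that appear in this hunk are shown, and the paths are placeholders.

dataset_kwargs = dict(
    control_path='/path/to/adapter_images',  # paired control images, e.g. CLIP inputs for an IP adapter
    full_size_control_images=True,  # serve the control image uncropped instead of matching the training crop
    flip_y=False,
    augments=[],
)
# DatasetConfig(**dataset_kwargs) picks these up via kwargs.get(...), per the hunk above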

View File

@@ -457,9 +457,11 @@ class ControlFileItemDTOMixin:
self.control_path: Union[str, None] = None
self.control_tensor: Union[torch.Tensor, None] = None
dataset_config: 'DatasetConfig' = kwargs.get('dataset_config', None)
self.full_size_control_images = False
if dataset_config.control_path is not None:
# find the control image path
control_path = dataset_config.control_path
self.full_size_control_images = dataset_config.full_size_control_images
# we are using control images
img_path = kwargs.get('path', None)
img_ext_list = ['.jpg', '.jpeg', '.png', '.webp']
@@ -477,36 +479,38 @@ class ControlFileItemDTOMixin:
except Exception as e:
print(f"Error: {e}")
print(f"Error loading image: {self.control_path}")
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.flip_x:
# do a flip (transpose returns a new image, so reassign)
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if not self.full_size_control_images:
w, h = img.size
if w > h and self.scale_to_width < self.scale_to_height:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
elif h > w and self.scale_to_height < self.scale_to_width:
# throw error, they should match
raise ValueError(
f"unexpected values: w={w}, h={h}, file_item.scale_to_width={self.scale_to_width}, file_item.scale_to_height={self.scale_to_height}, file_item.path={self.path}")
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
if self.flip_x:
# do a flip (transpose returns a new image, so reassign)
img = img.transpose(Image.FLIP_LEFT_RIGHT)
if self.flip_y:
# do a flip
img = img.transpose(Image.FLIP_TOP_BOTTOM)
if self.dataset_config.buckets:
# scale and crop based on file item
img = img.resize((self.scale_to_width, self.scale_to_height), Image.BICUBIC)
# img = transforms.CenterCrop((self.crop_height, self.crop_width))(img)
# crop
img = img.crop((
self.crop_x,
self.crop_y,
self.crop_x + self.crop_width,
self.crop_y + self.crop_height
))
else:
raise Exception("Control images not supported for non-bucket datasets")
self.control_tensor = transforms.ToTensor()(img)
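A condensed paraphrase (not from the diff) of the control-image load path after this change; attribute names mirror the hunk, while the aspect-ratio checks and flip handling are omitted for brevity.

from PIL import Image
from torchvision import transforms

def load_control_tensor(item, img):
    # full-size control images skip the bucket resize/crop entirely
    if not item.full_size_control_images:
        if not item.dataset_config.buckets:
            raise Exception("Control images not supported for non-bucket datasets")
        img = img.resize((item.scale_to_width, item.scale_to_height), Image.BICUBIC)
        img = img.crop((item.crop_x, item.crop_y,
                        item.crop_x + item.crop_width,
                        item.crop_y + item.crop_height))
    return transforms.ToTensor()(img)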

View File

@@ -6,6 +6,7 @@ from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
@@ -21,6 +22,16 @@ import weakref
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from diffusers.models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
# loosely based on https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
@@ -89,6 +100,16 @@ class IPAdapter(torch.nn.Module):
self.unet_ref: weakref.ref = weakref.ref(sd.unet)
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
# load the weights if we have some
if self.config.name_or_path:
loaded_state_dict = load_ip_adapter_model(
self.config.name_or_path,
device='cpu',
dtype=sd.torch_dtype
)
self.load_state_dict(loaded_state_dict)
self.set_scale(1.0)
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
@@ -102,6 +123,9 @@ class IPAdapter(torch.nn.Module):
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)
def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
state_dict["image_proj"] = self.image_proj_model.state_dict()
@@ -109,7 +133,7 @@ class IPAdapter(torch.nn.Module):
return state_dict
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
for attn_processor in self.sd_ref().unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
@@ -131,9 +155,21 @@ class IPAdapter(torch.nn.Module):
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
# if images are far outside the expected 0-1 range, throw an error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -155,3 +191,4 @@ class IPAdapter(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
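A minimal usage sketch (not from the diff) of the updated embedding path, assuming adapter is an already-initialized IPAdapter attached to a model; the tensor shape is illustrative.

import torch

# image tensor already scaled to [0, 1]; the processor now resizes it (do_resize=True)
# and skips rescaling (do_rescale=False) because the values are not 0-255
image_0_1 = torch.rand(1, 3, 512, 512)
cond_embeds = adapter.get_clip_image_embeds_from_tensors(image_0_1)
# drop=True zeroes the pixels, giving an "empty image" embedding for unconditional / reg use
uncond_embeds = adapter.get_clip_image_embeds_from_tensors(image_0_1, drop=True)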

View File

@@ -474,9 +474,7 @@ class StableDiffusion:
extra['adapter_conditioning_scale'] = gen_config.adapter_conditioning_scale
if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([
transforms.Resize(gen_config.width,
interpolation=transforms.InterpolationMode.BILINEAR),
transforms.PILToTensor(),
transforms.ToTensor(),
])
validation_image = transform(validation_image)
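The reasoning behind the transform change, as a commented sketch (not from the diff): ToTensor yields floats in [0, 1], which matches the new range check in get_clip_image_embeds_from_tensors, and resizing is now delegated to the CLIP image processor (do_resize=True).

from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]; no explicit Resize needed anymore
])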
@@ -500,6 +498,7 @@ class StableDiffusion:
if self.adapter is not None and isinstance(self.adapter,
IPAdapter) and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,