Fixed ip adapter training. Works now

This commit is contained in:
Jaret Burkett
2023-12-17 08:22:59 -07:00
parent 13d32423f6
commit b653906715
7 changed files with 102 additions and 40 deletions

View File

@@ -6,6 +6,7 @@ from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.paths import REPOS_ROOT
from toolkit.saving import load_ip_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
@@ -21,6 +22,16 @@ import weakref
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from diffusers.models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
@@ -89,6 +100,16 @@ class IPAdapter(torch.nn.Module):
self.unet_ref: weakref.ref = weakref.ref(sd.unet)
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
# load the weights if we have some
if self.config.name_or_path:
loaded_state_dict = load_ip_adapter_model(
self.config.name_or_path,
device='cpu',
dtype=sd.torch_dtype
)
self.load_state_dict(loaded_state_dict)
self.set_scale(1.0)
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
@@ -102,6 +123,9 @@ class IPAdapter(torch.nn.Module):
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)
def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
state_dict["image_proj"] = self.image_proj_model.state_dict()
@@ -109,7 +133,7 @@ class IPAdapter(torch.nn.Module):
return state_dict
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
for attn_processor in self.sd_ref().unet.attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
@@ -131,9 +155,21 @@ class IPAdapter(torch.nn.Module):
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -155,3 +191,4 @@ class IPAdapter(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)