mirror of https://github.com/ostris/ai-toolkit.git
Added IP adapter training. Not functioning correctly yet
3 .gitmodules vendored
@@ -7,3 +7,6 @@
[submodule "repositories/batch_annotator"]
path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator
[submodule "repositories/ipadapter"]
path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git
@@ -2,9 +2,11 @@ import os.path
from collections import OrderedDict

from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader

from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
@@ -115,13 +117,19 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter:
if self.adapter and isinstance(self.adapter, T2IAdapter):
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals

if self.adapter and isinstance(self.adapter, IPAdapter):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)


noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
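For orientation, the IP-Adapter branch above swaps plain text conditioning for text plus image tokens: the CLIP image embeds are projected by the adapter and concatenated onto the prompt embeddings along the token axis before `predict_noise` is called. A minimal sketch with made-up shapes (SD 1.x sizes assumed, not the toolkit's actual tensors):

```python
import torch

# Illustrative shapes only: 77 text tokens, 768-dim cross-attention, and 16
# extra image tokens as produced by the "ip+" Resampler.
text_embeds = torch.randn(2, 77, 768)    # from the text encoder
image_tokens = torch.randn(2, 16, 768)   # from adapter.image_proj_model(clip_image_embeds)

# This mirrors IPAdapter.forward(): the UNet then cross-attends over 93 tokens.
conditional_embeds = torch.cat([text_embeds, image_tokens], dim=1)
print(conditional_embeds.shape)  # torch.Size([2, 93, 768])
```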
@@ -164,6 +172,9 @@ class SDTrainer(BaseSDTrainProcess):
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

loss = loss.mean()
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")

# IMPORTANT if gradient checkpointing do not leave with network when doing backward
# it will destroy the gradients. This is because the network is a context manager
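`apply_snr_weight` re-weights the per-timestep loss before the mean is taken. A rough sketch of the usual min-SNR-gamma weighting (Hang et al., 2023) follows; the toolkit's implementation in `toolkit.train_tools` may differ in detail:

```python
import torch

def min_snr_gamma_weight(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # SNR of the noisy latent at each sampled timestep.
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1.0 - alpha_bar)
    # Clamp the weight so low-noise (high-SNR) steps do not dominate training.
    return torch.clamp(snr, max=gamma) / snr

# Hypothetical usage for an epsilon-prediction objective:
# loss = loss * min_snr_gamma_weight(timesteps, noise_scheduler.alphas_cumprod, gamma)
```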
@@ -15,6 +15,7 @@ from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
@@ -22,7 +23,8 @@ from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
load_ip_adapter_model

from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
@@ -118,9 +120,11 @@ class BaseSDTrainProcess(BaseTrainProcess):

# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, None] = None
self.embedding: Union[Embedding, None] = None

is_training_adapter = self.adapter_config is not None and self.adapter_config.train

# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
device=self.device_torch,
@@ -128,17 +132,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_adapter=self.adapter_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
)

# fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. it is (Dreambooth, etc)
self.is_fine_tuning = True
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if self.network_config is not None or is_training_adapter or self.embed_config is not None:
self.is_fine_tuning = False

self.named_lora = False
if self.embed_config is not None or self.adapter_config is not None:
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -179,7 +183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)

extra_args = {}
if self.adapter_config is not None:
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
extra_args['adapter_image_path'] = self.adapter_config.test_img_path

gen_img_config_list.append(GenerateImageConfig(
@@ -318,22 +322,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)

if self.adapter is not None:
if self.adapter is not None and self.adapter_config.train:
adapter_name = self.job.name
if self.network_config is not None or self.embedding is not None:
# add _lora to name
adapter_name += '_t2i'
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
else:
adapter_name += '_ip'

filename = f'{adapter_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
if self.adapter_config.type == 't2i':
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
save_ip_adapter_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
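The naming above means an IP adapter trained alongside a LoRA network or embedding is written under its own suffixed name. Purely illustrative (the job name, step suffix, and save root below are hypothetical):

```python
import os

job_name, step_num = "my_job", "_000000500"   # hypothetical values
save_root = "output/my_job"                   # hypothetical path

adapter_name = job_name + "_ip"               # "_t2i" for a T2I adapter
filename = f"{adapter_name}{step_num}.safetensors"
file_path = os.path.join(save_root, filename)  # output/my_job/my_job_ip_000000500.safetensors
```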
@@ -527,6 +542,50 @@ class BaseSDTrainProcess(BaseTrainProcess):

return noisy_latents, noise, timesteps, conditioned_prompts, imgs

def setup_adapter(self):
dtype = get_torch_dtype(self.train_config.dtype)
is_t2i = self.adapter_config.type == 't2i'
if is_t2i:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
else:
self.adapter = IPAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
if is_t2i:
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
else:
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
self.sd.adapter = self.adapter

def run(self):
# torch.autograd.set_detect_anomaly(True)
# run base process run
@@ -741,35 +800,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()

if self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_t2i"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.sd.adapter = self.adapter
flush()

params = self.load_additional_training_modules(params)
@@ -785,6 +821,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# we may be using it for prompt injections
if self.adapter_config is not None:
self.setup_adapter()
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)
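Whether the adapter is built by the old inline block or the new `setup_adapter()`, its weights end up in their own optimizer parameter group with a separate learning rate. A sketch of the resulting structure (the modules, rates, and optimizer choice here are placeholders; the toolkit resolves the real optimizer via `get_optimizer`):

```python
import torch

unet = torch.nn.Linear(8, 8)      # stand-in for the main trainable params
adapter = torch.nn.Linear(8, 8)   # stand-in for self.adapter.parameters()

params = [
    {'params': unet.parameters(), 'lr': 1e-4},     # train_config.lr
    {'params': adapter.parameters(), 'lr': 1e-4},  # train_config.adapter_lr
]
optimizer = torch.optim.AdamW(params)
```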
1 repositories/ipadapter Submodule
Submodule repositories/ipadapter added at d8ab37c421
@@ -61,8 +61,11 @@ class NetworkConfig:
self.dropout: Union[float, None] = kwargs.get('dropout', None)


AdapterTypes = Literal['t2i', 'ip', 'ip+']

class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i')  # t2i, ip, or ip+
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
@@ -70,6 +73,9 @@ class AdapterConfig:
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
self.image_dir: str = kwargs.get('image_dir', None)
self.test_img_path: str = kwargs.get('test_img_path', None)
self.train: bool = kwargs.get('train', False)
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
self.name_or_path = kwargs.get('name_or_path', None)


class EmbeddingConfig:
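Put together, an IP adapter run is driven by an `AdapterConfig` built from the training config's kwargs. A hypothetical instantiation (all paths and values below are placeholders, not defaults from the repo):

```python
from toolkit.config_modules import AdapterConfig

adapter_config = AdapterConfig(
    type='ip+',                                        # 't2i', 'ip', or 'ip+'
    image_encoder_path='path/to/clip_vision_encoder',  # placeholder path
    image_dir='path/to/adapter_images',                # placeholder path
    test_img_path='path/to/test_image.png',            # placeholder path
    train=True,
)
```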
157 toolkit/ip_adapter.py Normal file
@@ -0,0 +1,157 @@
import torch
import sys

from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype

sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion

# loosely based on https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""

def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
super().__init__()
self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd)
self.clip_image_processor = CLIPImageProcessor()
self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
if adapter_config.type == 'ip':
# ip-adapter
image_proj_model = ImageProjModel(
cross_attention_dim=sd.unet.config['cross_attention_dim'],
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=4,
)
elif adapter_config.type == 'ip+':
# ip-adapter-plus
num_tokens = 16
image_proj_model = Resampler(
dim=sd.unet.config['cross_attention_dim'],
depth=4,
dim_head=64,
heads=12,
num_queries=num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4
)
else:
raise ValueError(f"unknown adapter type: {adapter_config.type}")

# init adapter modules
attn_procs = {}
unet_sd = sd.unet.state_dict()
for name in sd.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):
hidden_size = sd.unet.config['block_out_channels'][-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
else:
# upstream didn't have this else branch, but without it hidden_size would be undefined below
raise ValueError(f"unknown attn processor name: {name}")
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
attn_procs[name].load_state_dict(weights)
sd.unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

sd.adapter = self
self.unet_ref: weakref.ref = weakref.ref(sd.unet)
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
return self

def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.unet_ref().attn_processors.values())  # IPAdapter has no pipe; use the UNet weakref
ip_layers.load_state_dict(state_dict["ip_adapter"])

def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
state_dict["image_proj"] = self.image_proj_model.state_dict()
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
return state_dict

def set_scale(self, scale):
for attn_processor in self.unet_ref().attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale

@torch.no_grad()
def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
# todo: add support for sdxl
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds

@torch.no_grad()
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
# tensors should be 0-1
# todo: add support for sdxl
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
if drop:
clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds

# use drop for prompt dropout, or negatives
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.detach()
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach())
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
return embeddings

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
for attn_processor in self.adapter_modules:
yield from attn_processor.parameters(recurse)
yield from self.image_proj_model.parameters(recurse)

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
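The `IPAttnProcessor` modules registered in the constructor implement IP-Adapter's decoupled cross-attention: the query attends over the text tokens through the UNet's existing `to_k`/`to_v`, separately over the image tokens through the new `to_k_ip`/`to_v_ip`, and the two results are summed. A stripped-down, single-head sketch of that idea (not the actual processor, which also handles head reshaping, masks, and projection back out):

```python
import torch
import torch.nn.functional as F

dim, num_image_tokens = 768, 16
to_k = torch.nn.Linear(dim, dim, bias=False)
to_v = torch.nn.Linear(dim, dim, bias=False)
to_k_ip = torch.nn.Linear(dim, dim, bias=False)   # the weights the adapter actually trains
to_v_ip = torch.nn.Linear(dim, dim, bias=False)

query = torch.randn(1, 4096, dim)                          # latent queries
encoder_hidden_states = torch.randn(1, 77 + num_image_tokens, dim)

# Split the concatenated conditioning back into text and image tokens.
text_tokens = encoder_hidden_states[:, :-num_image_tokens]
image_tokens = encoder_hidden_states[:, -num_image_tokens:]

text_out = F.scaled_dot_product_attention(query, to_k(text_tokens), to_v(text_tokens))
image_out = F.scaled_dot_product_attention(query, to_k_ip(image_tokens), to_v_ip(image_tokens))
scale = 1.0
hidden_states = text_out + scale * image_out
```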
@@ -306,6 +306,7 @@ class ToolkitNetworkMixin:
extra_dict = None
return extra_dict

@torch.no_grad()
def _update_torch_multiplier(self: Network):
# builds a tensor for fast usage in the forward pass of the network modules
# without having to set it in every single module every time it changes
@@ -193,3 +193,42 @@ def load_t2i_model(
# todo see if we need to convert dict
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
return converted_state_dict


IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']

def save_ip_adapter_from_diffusers(
combined_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)

# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)


def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
if module_name not in combined_state_dict:
combined_state_dict[module_name] = OrderedDict()
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)
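The two helpers above simply flatten and unflatten the nested `{'image_proj': ..., 'ip_adapter': ...}` dict that `IPAdapter.state_dict()` returns, since safetensors only stores a flat key-to-tensor map. A toy round trip of that key scheme (the parameter names are made up):

```python
from collections import OrderedDict

import torch

nested = OrderedDict(
    image_proj=OrderedDict({'proj.weight': torch.zeros(4)}),
    ip_adapter=OrderedDict({'0.to_k_ip.weight': torch.zeros(4)}),
)

# Flatten to "<module>.<param>" keys, as written by save_ip_adapter_from_diffusers.
flat = OrderedDict((f"{mod}.{k}", v) for mod, sd in nested.items() for k, v in sd.items())

# Unflatten by splitting on the first dot, as done by load_ip_adapter_model.
rebuilt = OrderedDict()
for combo_key, value in flat.items():
    module_name, _, rest = combo_key.partition('.')
    rebuilt.setdefault(module_name, OrderedDict())[rest] = value

assert list(rebuilt.keys()) == ['image_proj', 'ip_adapter']
```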
@@ -13,8 +13,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
from safetensors.torch import save_file, load_file
from torch.nn import Parameter
from tqdm import tqdm
from torchvision.transforms import Resize
from torchvision.transforms import Resize, transforms

from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
from toolkit import train_tools
@@ -123,7 +124,7 @@ class StableDiffusion:

# to hold network if there is one
self.network = None
self.adapter: Union['T2IAdapter', None] = None
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
@@ -302,12 +303,13 @@ class StableDiffusion:
Pipe = StableDiffusionPipeline

extra_args = {}
if self.adapter:
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
if self.adapter is not None:
if isinstance(self.adapter, T2IAdapter):
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter

# TODO add clip skip
if self.is_xl:
@@ -358,11 +360,19 @@ class StableDiffusion:
gen_config = image_configs[i]

extra = {}
if gen_config.adapter_image_path is not None:
if self.adapter is not None and gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if isinstance(self.adapter, T2IAdapter):
# not sure why this is double??
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.PILToTensor(),
])
validation_image = transform(validation_image)

if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
@@ -382,6 +392,14 @@ class StableDiffusion:
unconditional_embeds,
)

if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)


# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
# fix guidance rescale for sdxl
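One thing to watch in the sampling path above: `get_clip_image_embeds_from_tensors` documents its input as 0-1 tensors, while torchvision's `PILToTensor()` returns uint8 values in [0, 255], which may be related to the commit being flagged as not yet working. A sketch of preprocessing that would yield the expected range (the 512 target size stands in for `gen_config.width`, and the file name is a placeholder):

```python
from PIL import Image
from torchvision import transforms

to_unit_tensor = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),  # float32 in [0, 1], unlike PILToTensor()
])

validation_image = to_unit_tensor(Image.open("reference.png").convert("RGB"))  # placeholder file
```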
@@ -931,10 +949,18 @@ class StableDiffusion:
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
}
if self.adapter is not None:
if isinstance(self.adapter, IPAdapter):
requires_grad = self.adapter.adapter_modules.training
adapter_device = self.unet.device
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
else:
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
self.device_state['adapter'] = {
'training': self.adapter.training,
'device': self.adapter.device,
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
'device': adapter_device,
'requires_grad': requires_grad,
}

def restore_device_state(self):